forked from ZrrSkywalker/Personalize-SAM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
persam_f.py
324 lines (253 loc) · 11.3 KB
/
persam_f.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import os
import cv2
from tqdm import tqdm
import argparse
import matplotlib.pyplot as plt
import warnings
warnings.filterwarnings('ignore')
from show import *
from per_segment_anything import sam_model_registry, SamPredictor
def get_arguments(argv=None):
    """Parse the PerSAM-F command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing an explicit list lets
            the parser be used from tests or notebooks.

    Returns:
        argparse.Namespace holding ``data``, ``outdir``, ``ckpt``,
        ``sam_type``, ``lr``, ``train_epoch``, ``log_epoch`` and ``ref_idx``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default='./data',
                        help='dataset root containing Images/ and Annotations/')
    parser.add_argument('--outdir', type=str, default='persam_f',
                        help='output sub-directory under ./outputs/')
    parser.add_argument('--ckpt', type=str, default='./sam_vit_h_4b8939.pth',
                        help='path to the SAM checkpoint')
    parser.add_argument('--sam_type', type=str, default='vit_h',
                        help="SAM backbone: 'vit_h' or 'vit_t'")
    parser.add_argument('--lr', type=float, default=1e-3,
                        help='AdamW learning rate for the mask weights')
    parser.add_argument('--train_epoch', type=int, default=1000,
                        help='number of optimisation steps per object')
    parser.add_argument('--log_epoch', type=int, default=200,
                        help='print losses every this many steps')
    parser.add_argument('--ref_idx', type=str, default='00',
                        help='file stem of the reference image/mask (e.g. 00 -> 00.jpg / 00.png)')
    return parser.parse_args(argv)
def main():
    """Run PerSAM-F on every object folder found under ``<data>/Images``.

    Each object's results are written by ``persam_f`` below
    ``./outputs/<outdir>/<object name>``.
    """
    args = get_arguments()
    print("Args:", args)

    images_path = os.path.join(args.data, 'Images')
    masks_path = os.path.join(args.data, 'Annotations')
    output_path = os.path.join('./outputs', args.outdir)
    # makedirs(exist_ok=True) replaces the racy exists()/mkdir() pair and also
    # creates any missing intermediate directories.
    os.makedirs(output_path, exist_ok=True)

    # Sorted for a reproducible processing order. Plain files (e.g. .DS_Store,
    # a README) are skipped because persam_f expects one folder per object.
    for obj_name in sorted(os.listdir(images_path)):
        if ".DS" in obj_name or not os.path.isdir(os.path.join(images_path, obj_name)):
            continue
        persam_f(args, obj_name, images_path, masks_path, output_path)
def persam_f(args, obj_name, images_path, masks_path, output_path):
    """Fine-tune PerSAM-F mask weights on one reference image, then segment
    every test image of the same object.

    Steps for one object:
      1. Load the reference image ``<ref_idx>.jpg`` and mask ``<ref_idx>.png``.
      2. Build a frozen SAM predictor (``vit_h`` or ``vit_t``).
      3. Pool the masked reference features into a target embedding and use
         cosine similarity to pick a positive point prompt.
      4. Train the two free weights of ``Mask_Weights`` so that the weighted
         sum of SAM's three multimask logits matches the reference mask
         (dice + focal loss).
      5. For every ``*.jpg`` in the object folder: locate the object with the
         target embedding, predict with the learned weights, run two cascaded
         box/mask refinements, and save an overlay JPG and a PNG mask.

    Args:
        args: Parsed CLI namespace (see ``get_arguments``).
        obj_name: Name of the object's sub-folder under both data roots.
        images_path: Directory containing one image folder per object.
        masks_path: Directory containing one annotation folder per object.
        output_path: Root output directory; results go to ``output_path/obj_name``.

    Raises:
        FileNotFoundError: if the reference image/mask or a test image cannot be read.
        ValueError: if ``args.sam_type`` is neither ``'vit_h'`` nor ``'vit_t'``.
    """
    print("\n------------> Segment " + obj_name)

    # Path preparation
    ref_image_path = os.path.join(images_path, obj_name, args.ref_idx + '.jpg')
    ref_mask_path = os.path.join(masks_path, obj_name, args.ref_idx + '.png')
    test_images_path = os.path.join(images_path, obj_name)
    output_path = os.path.join(output_path, obj_name)
    os.makedirs(output_path, exist_ok=True)

    # One device for SAM, the target mask and the learnable weights.
    # Previously only the vit_t branch fell back to CPU, while gt_mask and
    # Mask_Weights were always moved to CUDA, so that fallback could never run.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load images and masks. The mask PNG is read as 3 channels; channel 0 > 0
    # is treated as foreground.
    ref_image = _read_rgb(ref_image_path)
    ref_mask = _read_rgb(ref_mask_path)
    gt_mask = torch.tensor(ref_mask)[:, :, 0] > 0
    gt_mask = gt_mask.float().unsqueeze(0).flatten(1).to(device)  # (1, H*W)

    print("======> Load SAM" )
    predictor = SamPredictor(_load_frozen_sam(args, device))

    print("======> Obtain Self Location Prior" )
    # This per_segment_anything set_image variant also takes the reference mask
    # and returns a mask tensor. Presumably it is aligned to SAM's input frame;
    # confirm in per_segment_anything.
    ref_mask = predictor.set_image(ref_image, ref_mask)
    ref_feat = predictor.features.squeeze().permute(1, 2, 0)  # (h, w, C)
    ref_mask = F.interpolate(ref_mask, size=ref_feat.shape[0: 2], mode="bilinear")
    ref_mask = ref_mask.squeeze()[0]

    # Target embedding: the average of the mean-pooled and max-pooled features
    # inside the reference mask, L2-normalised for cosine similarity.
    target_feat = ref_feat[ref_mask > 0]
    target_feat_mean = target_feat.mean(0)
    target_feat_max = torch.max(target_feat, dim=0)[0]
    target_feat = (target_feat_max / 2 + target_feat_mean / 2).unsqueeze(0)
    target_feat = target_feat / target_feat.norm(dim=-1, keepdim=True)

    # Positive location prior on the reference image itself
    topk_xy, topk_label = _location_prior(predictor, target_feat)

    print('======> Start Training')
    # Learnable mask weights
    mask_weights = Mask_Weights().to(device)
    mask_weights.train()
    optimizer = torch.optim.AdamW(mask_weights.parameters(), lr=args.lr, eps=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.train_epoch)

    # SAM is in eval mode with requires_grad=False, and the prompt is fixed,
    # so the decoder output is identical in every epoch. Run it once and only
    # re-weight it inside the loop.
    _, _, _, ref_logits_high = predictor.predict(
        point_coords=topk_xy,
        point_labels=topk_label,
        multimask_output=True)
    ref_logits_high = ref_logits_high.flatten(1)  # (3, H*W)

    for train_idx in range(args.train_epoch):
        # Weighted sum of the three-scale masks
        weights = _scale_weights(mask_weights)  # (3, 1)
        logits_high = (ref_logits_high * weights).sum(0).unsqueeze(0)

        dice_loss = calculate_dice_loss(logits_high, gt_mask)
        focal_loss = calculate_sigmoid_focal_loss(logits_high, gt_mask)
        loss = dice_loss + focal_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step()

        if train_idx % args.log_epoch == 0:
            print('Train Epoch: {:} / {:}'.format(train_idx, args.train_epoch))
            current_lr = scheduler.get_last_lr()[0]
            print('LR: {:.6f}, Dice_Loss: {:.4f}, Focal_Loss: {:.4f}'.format(current_lr, dice_loss.item(), focal_loss.item()))

    mask_weights.eval()
    # From here on the weights are constants, so detach them and build no
    # autograd graph at test time.
    weights = _scale_weights(mask_weights).detach()
    weights_np = weights.cpu().numpy()
    print('======> Mask weights:\n', weights_np)

    print('======> Start Testing')
    # Iterate over the .jpg files that actually exist. The old loop assumed
    # 00.jpg..NN.jpg with NN equal to the number of directory entries, which
    # broke on stray files (e.g. .DS_Store) or on gaps in the numbering.
    test_names = sorted(name for name in os.listdir(test_images_path) if name.endswith('.jpg'))
    for test_name in tqdm(test_names):
        test_idx = os.path.splitext(test_name)[0]
        test_image = _read_rgb(os.path.join(test_images_path, test_name))

        # Image feature encoding + positive location prior
        predictor.set_image(test_image)
        topk_xy, topk_label = _location_prior(predictor, target_feat)

        # First-step prediction
        masks, scores, logits, logits_high = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            multimask_output=True)

        # Weighted sum of the three-scale masks. The high-res logits give the
        # first mask; the low-res logits become the mask prompt for refinement.
        logit_high = (logits_high * weights.unsqueeze(-1)).sum(0)
        mask = (logit_high > 0).detach().cpu().numpy()
        logit = (logits * weights_np[..., None]).sum(0)

        # Cascaded Post-refinement-1: box around the first mask + weighted logits
        masks, scores, logits, _ = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=_mask_to_box(mask),
            mask_input=logit[None, :, :],
            multimask_output=True)
        best_idx = np.argmax(scores)

        # Cascaded Post-refinement-2: box around the best refined mask + its logits
        masks, scores, logits, _ = predictor.predict(
            point_coords=topk_xy,
            point_labels=topk_label,
            box=_mask_to_box(masks[best_idx]),
            mask_input=logits[best_idx: best_idx + 1, :, :],
            multimask_output=True)
        best_idx = np.argmax(scores)

        _save_outputs(test_image, masks[best_idx], best_idx, topk_xy, topk_label,
                      output_path, test_idx)


def _read_rgb(path):
    """Read an image with OpenCV and convert it from BGR to RGB.

    cv2.imread returns None instead of raising on a missing/unreadable file,
    which used to surface as an obscure cvtColor error; raise clearly instead.
    """
    image = cv2.imread(path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {path}")
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def _load_frozen_sam(args, device):
    """Build the SAM model selected by ``args.sam_type`` on ``device``, in eval
    mode with all parameters frozen."""
    if args.sam_type == 'vit_h':
        # Honour --ckpt (previously ignored); its default is the same file
        # that used to be hard-coded here.
        sam_type, sam_ckpt = 'vit_h', args.ckpt
    elif args.sam_type == 'vit_t':
        # --ckpt defaults to the vit_h checkpoint, so MobileSAM keeps its own path.
        sam_type, sam_ckpt = 'vit_t', 'weights/mobile_sam.pt'
    else:
        raise ValueError(f"Unsupported --sam_type {args.sam_type!r}; expected 'vit_h' or 'vit_t'")
    sam = sam_model_registry[sam_type](checkpoint=sam_ckpt).to(device=device)
    sam.eval()
    for param in sam.parameters():
        param.requires_grad = False
    return sam


def _location_prior(predictor, target_feat):
    """Return the top-1 positive point prompt ``(coords, labels)``.

    The point is the pixel where the L2-normalised ``target_feat`` (1, C) has
    the highest cosine similarity with the predictor's current image features.
    """
    feat = predictor.features.squeeze()
    C, h, w = feat.shape
    feat = feat / feat.norm(dim=0, keepdim=True)
    sim = target_feat @ feat.reshape(C, h * w)
    sim = sim.reshape(1, 1, h, w)
    sim = F.interpolate(sim, scale_factor=4, mode="bilinear")
    sim = predictor.model.postprocess_masks(
        sim,
        input_size=predictor.input_size,
        original_size=predictor.original_size).squeeze()
    return point_selection(sim, topk=1)


def _scale_weights(mask_weights):
    """Expand the two learnable weights into three that sum to one:
    ``[1 - w1 - w2, w1, w2]`` with shape (3, 1)."""
    free = mask_weights.weights
    return torch.cat((1 - free.sum(0).unsqueeze(0), free), dim=0)


def _mask_to_box(mask):
    """Return the tight box ``[[x_min, y_min, x_max, y_max]]`` around the
    nonzero pixels of a 2-D mask.

    Returns None for an empty mask, so the refinement uses the point prompt
    alone instead of crashing on ``min()`` of an empty array.
    """
    ys, xs = np.nonzero(mask)
    if xs.size == 0:
        return None
    return np.array([xs.min(), ys.min(), xs.max(), ys.max()])[None, :]


def _save_outputs(image, mask, best_idx, points, labels, output_path, test_idx):
    """Write ``vis_mask_<test_idx>.jpg`` (overlay) and ``<test_idx>.png``
    (mask painted [0, 0, 128] in OpenCV's BGR order) for one test image."""
    fig = plt.figure(figsize=(10, 10))
    plt.imshow(image)
    show_mask(mask, plt.gca())
    show_points(points, labels, plt.gca())
    plt.title(f"Mask {best_idx}", fontsize=18)
    plt.axis('off')
    vis_mask_output_path = os.path.join(output_path, f'vis_mask_{test_idx}.jpg')
    with open(vis_mask_output_path, 'wb') as outfile:
        plt.savefig(outfile, format='jpg')
    # pyplot keeps every figure alive until it is closed, so without this the
    # memory grows with the number of test images.
    plt.close(fig)

    mask_colors = np.zeros((mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    mask_colors[mask, :] = np.array([[0, 0, 128]])
    mask_output_path = os.path.join(output_path, test_idx + '.png')
    cv2.imwrite(mask_output_path, mask_colors)
class Mask_Weights(nn.Module):
    """Learnable fusion weights for SAM's three multimask outputs.

    Stores just two free scalars in ``self.weights`` with shape (2, 1), each
    initialised to 1/3. In this file, ``persam_f`` derives the third weight
    as ``1 - sum(weights)``, so all three sum to one.
    """

    def __init__(self):
        super().__init__()
        initial = torch.ones(2, 1, requires_grad=True) / 3
        self.weights = nn.Parameter(initial)
def point_selection(mask_sim, topk=1):
    """Pick the ``topk`` highest-scoring pixels of a 2-D score map as positive
    point prompts.

    Args:
        mask_sim: 2-D tensor of scores indexed as ``[row, col]``.
        topk: Number of points to return.

    Returns:
        ``(coords, labels)``. ``coords`` is a numpy integer array of shape
        (topk, 2) holding ``(col, row)``, i.e. ``(x, y)``, for each point, in
        ``Tensor.topk`` order. ``labels`` is a numpy array of ``topk`` ones
        (foreground).
    """
    _, n_cols = mask_sim.shape
    flat_idx = mask_sim.flatten(0).topk(topk)[1]
    rows = flat_idx // n_cols
    cols = flat_idx - rows * n_cols
    coords = torch.stack((cols, rows), dim=1).cpu().numpy()
    labels = np.array([1] * topk)
    return coords, labels
def calculate_dice_loss(inputs, targets, num_masks = 1):
    """Compute the smoothed DICE loss, similar to generalized IoU for masks.

    Args:
        inputs: Float tensor of logits with shape (N, ...), one entry per mask.
            The sigmoid is applied here.
        targets: Float tensor with the same shape as ``inputs`` holding the
            binary label of each element (0 = negative, 1 = positive).
        num_masks: Normaliser applied to the summed per-mask loss.

    Returns:
        Scalar tensor equal to ``sum_n (1 - (2|X∩Y| + 1) / (|X| + |Y| + 1)) / num_masks``.
    """
    inputs = inputs.sigmoid().flatten(1)
    # Flatten the targets too. Previously only the inputs were flattened, so
    # (N, H, W) inputs failed to broadcast against unflattened targets.
    targets = targets.flatten(1)
    numerator = 2 * (inputs * targets).sum(-1)
    denominator = inputs.sum(-1) + targets.sum(-1)
    # The +1 smoothing keeps the loss defined when both masks are empty.
    loss = 1 - (numerator + 1) / (denominator + 1)
    return loss.sum() / num_masks
def calculate_sigmoid_focal_loss(inputs, targets, num_masks = 1, alpha: float = 0.25, gamma: float = 2):
    """Sigmoid focal loss, as used in RetinaNet for dense detection:
    https://arxiv.org/abs/1708.02002.

    Args:
        inputs: Float tensor of logits with shape (N, ...), one entry per mask.
        targets: Float tensor with the same shape as ``inputs`` holding the
            binary label of each element (0 = negative, 1 = positive).
        num_masks: Normaliser applied to the summed per-mask loss.
        alpha: Weighting factor in (0, 1) that balances positive vs. negative
            examples. Default 0.25; pass a negative value to disable weighting.
        gamma: Exponent of the modulating factor ``(1 - p_t)``, which balances
            easy vs. hard examples.

    Returns:
        Scalar tensor: the per-mask mean over all elements, summed over masks
        and divided by ``num_masks``.
    """
    prob = inputs.sigmoid()
    ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction="none")
    p_t = prob * targets + (1 - prob) * (1 - targets)
    loss = ce_loss * ((1 - p_t) ** gamma)
    if alpha >= 0:
        alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
        loss = alpha_t * loss
    # Average over every element of each mask, not just dim 1, so (N, H, W)
    # inputs give the same value as their (N, H*W) flattening.
    return loss.flatten(1).mean(1).sum() / num_masks
# Script entry point: parse the CLI options and run PerSAM-F over every
# object folder in the dataset.
if __name__ == '__main__':
    main()