add relative coords
Epiphqny authored Apr 7, 2020
1 parent ea3f717 commit 1b03b70
Showing 1 changed file with 68 additions and 44 deletions.
112 changes: 68 additions & 44 deletions fcos/modeling/fcos/fcos_outputs.py
@@ -157,7 +157,7 @@ def _get_ground_truth(self):
training_targets = self.compute_targets_for_locations(
locations, self.gt_instances, loc_to_size_range
)

# transpose the image-first training_targets into level-first ones
training_targets = {
k: self._transpose(v, num_loc_list) for k, v in training_targets.items()
@@ -168,7 +168,7 @@ def _get_ground_truth(self):
for l in range(len(reg_targets)):
reg_targets[l] = reg_targets[l] / float(self.strides[l])

return training_targets
return training_targets, locations
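
`_get_ground_truth` now also returns the per-level point locations, so the loss can later map each positive sample back to a cell on the mask feature map. For orientation, a minimal sketch of how FCOS-style locations are usually generated per feature level; the helper name and signature here are illustrative, not this repository's exact API:

import torch

def compute_locations_sketch(h, w, stride, device="cpu"):
    # Center of every feature-map cell, expressed in input-image pixels.
    shifts_x = torch.arange(0, w * stride, step=stride, dtype=torch.float32, device=device)
    shifts_y = torch.arange(0, h * stride, step=stride, dtype=torch.float32, device=device)
    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
    locations = torch.stack((shift_x.reshape(-1), shift_y.reshape(-1)), dim=1) + stride // 2
    return locations  # (h * w, 2), ordered (x, y)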

def get_sample_region(self, gt, strides, num_loc_list, loc_xs, loc_ys, radius=1):
num_gts = gt.shape[0]
@@ -269,7 +269,7 @@ def losses(self):
dict[loss name -> loss value]: A dict mapping from loss name to loss value.
"""

training_targets = self._get_ground_truth()
training_targets, locations = self._get_ground_truth()
labels, reg_targets, matched_idxes, im_idxes = (
    training_targets["labels"],
    training_targets["reg_targets"],
    training_targets["matched_idxes"],
    training_targets["im_idxes"],
)

# Collect all logits and regression predictions over feature maps
@@ -319,6 +319,8 @@ def losses(self):
[
x.permute(0, 2, 3, 1).reshape(-1, 169) for x in self.controllers
], dim=0,)
locations = cat([locations//self.strides[0]]*4)


return self.fcos_losses(
labels,
@@ -331,7 +333,8 @@ def losses(self):
self.focal_loss_gamma,
self.iou_loss,
matched_idxes,
im_idxes
im_idxes,
locations
)
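
`locations` above are absolute input-image coordinates; dividing by `self.strides[0]` converts them to cell indices on the stride-8 mask feature map, and the result is repeated four times, presumably matching the number of prediction levels concatenated here (an assumption read off the `*4` in the code). A tiny worked example, assuming `strides[0] == 8`:

import torch

locations = torch.tensor([[100.0, 52.0]])  # absolute (x, y) in input pixels
cells = locations // 8                     # tensor([[12., 6.]]) on the stride-8 grid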

def predict_proposals(self):
@@ -365,29 +368,38 @@ def predict_proposals(self):

def forward_for_mask(self, boxlists):
N, dim, h, w = self.masks.shape
grid_x = torch.arange(w).view(1,-1).float().repeat(h,1).cuda() / (w-1) * 2 - 1
grid_y = torch.arange(h).view(-1,1).float().repeat(1,w).cuda() / (h-1) * 2 - 1
x_map = grid_x.view(1, 1, h, w).repeat(N, 1, 1, 1)
y_map = grid_y.view(1, 1, h, w).repeat(N, 1, 1, 1)
masks_feat = torch.cat((self.masks, x_map, y_map), dim=1)
o_h = int(h * self.strides[0])
o_w = int(w * self.strides[0])
x_range = torch.linspace(-1, 1, w, device=self.masks.device)
y_range = torch.linspace(-1, 1, h, device=self.masks.device)
y, x = torch.meshgrid(y_range, x_range)
x = x.unsqueeze(0).unsqueeze(0)
y = y.unsqueeze(0).unsqueeze(0)
grid = torch.cat([x,y],1)
#masks_feat = torch.cat((self.masks, x_map, y_map), dim=1)
#o_h = int(h * self.strides[0])
#o_w = int(w * self.strides[0])
for im in range(N):
boxlist = boxlists[im]
input_h, input_w = boxlist.image_size
mask = masks_feat[None, im]
pred_boxes = boxlists[im].pred_boxes.tensor / 8
# check if height and width are correct
center_x = torch.clamp((pred_boxes[:,0] + pred_boxes[:,2])/2, min=0, max=w-1).long()
center_y = torch.clamp((pred_boxes[:,1] + pred_boxes[:,3])/2, min=0, max=h-1).long()
offset_x = x_range[center_x].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
offset_y = y_range[center_y].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
offset_xy = torch.cat([offset_x,offset_y],1)
coords_feat = grid-offset_xy
mask_feat = self.masks[None, im]
ins_num = boxlist.controllers.shape[0]
weights1 = boxlist.controllers[:,:80].reshape(-1,8,10).reshape(-1,10).unsqueeze(-1).unsqueeze(-1)
bias1 = boxlist.controllers[:, 80:88].flatten()
weights2 = boxlist.controllers[:, 88:152].reshape(-1,8,8).reshape(-1,8).unsqueeze(-1).unsqueeze(-1)
bias2 = boxlist.controllers[:, 152:160].flatten()
weights3 = boxlist.controllers[:, 160:168].unsqueeze(-1).unsqueeze(-1)
bias3 = boxlist.controllers[:,168:169].flatten()

conv1 = F.conv2d(mask,weights1,bias1).relu()
conv2 = F.conv2d(conv1, weights2, bias2, groups = ins_num).relu()
masks_per_image = F.conv2d(conv2, weights3, bias3, groups = ins_num)
#masks = interpolate(masks_per_image, size = (o_h,o_w), mode="bilinear", align_corners=False).sigmoid()
mask_feat = torch.cat([mask_feat]*ins_num,dim=0)
comb_feat = torch.cat((mask_feat, coords_feat),dim=1).view(1,-1,h,w)
weight1, bias1, weight2, bias2, weight3, bias3 = torch.split(boxlist.controllers,[80, 8, 64, 8, 8, 1], dim=1)
bias1, bias2, bias3 = bias1.flatten(), bias2.flatten(),bias3.flatten()
weight1 = weight1.reshape(-1,8,10).reshape(-1,10).unsqueeze(-1).unsqueeze(-1)
weight2 = weight2.reshape(-1,8,8).reshape(-1,8).unsqueeze(-1).unsqueeze(-1)
weight3 = weight3.unsqueeze(-1).unsqueeze(-1)
conv1 = F.conv2d(comb_feat, weight1, bias1, groups = ins_num).relu()
conv2 = F.conv2d(conv1, weight2, bias2, groups = ins_num).relu()
masks_per_image = F.conv2d(conv2, weight3, bias3, groups = ins_num)
masks = aligned_bilinear(masks_per_image, self.strides[0]).sigmoid()
masks = masks[:, :, :input_h, :input_w].permute(1,0,2,3)
boxlist.pred_masks = masks
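
The inference path above implements a CondInst-style dynamic mask head with relative coordinates: a coordinate grid normalized to [-1, 1] is shifted by each detected box's center, concatenated with the shared mask features, and filtered by three 1x1 convolutions whose weights and biases come from the 169-dim controller vector (split as 80+8, 64+8, 8+1). A self-contained sketch of the grouped dynamic convolution; shapes and the function name are illustrative:

import torch
import torch.nn.functional as F

def dynamic_mask_head(mask_feat, rel_coords, controllers):
    # mask_feat:   (n, 8, h, w) shared mask features, tiled per instance
    # rel_coords:  (n, 2, h, w) coordinate maps relative to each instance center
    # controllers: (n, 169)     predicted filter parameters
    n, _, h, w = mask_feat.shape
    x = torch.cat((mask_feat, rel_coords), dim=1).view(1, -1, h, w)  # one group per instance
    w1, b1, w2, b2, w3, b3 = torch.split(controllers, [80, 8, 64, 8, 8, 1], dim=1)
    w1 = w1.reshape(n * 8, 10, 1, 1)  # 10 input channels (8 feat + 2 coords) -> 8
    w2 = w2.reshape(n * 8, 8, 1, 1)   # 8 -> 8
    w3 = w3.reshape(n, 8, 1, 1)       # 8 -> 1 mask logit per instance
    x = F.conv2d(x, w1, b1.flatten(), groups=n).relu()
    x = F.conv2d(x, w2, b2.flatten(), groups=n).relu()
    return F.conv2d(x, w3, b3.flatten(), groups=n).view(n, 1, h, w)

The diff reaches the same result by flattening the per-instance tensors into a single (1, n*10, h, w) input and relying on `groups=ins_num`.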
@@ -523,7 +535,8 @@ def fcos_losses(
focal_loss_gamma,
iou_loss,
matched_idxes,
im_idxes
im_idxes,
locations
):
num_classes = logits_pred.size(1)
labels = labels.flatten()
@@ -552,6 +565,7 @@
controllers_pred = controllers_pred[pos_inds]
matched_idxes = matched_idxes[pos_inds]
im_idxes = im_idxes[pos_inds]
locations = locations[pos_inds]

ctrness_targets = compute_ctrness_targets(reg_targets)
ctrness_targets_sum = ctrness_targets.sum()
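
For reference, `compute_ctrness_targets` in FCOS turns the (l, t, r, b) regression targets into a centerness score; a sketch of the standard definition, which this repository is assumed to follow:

import torch

def compute_ctrness_targets_sketch(reg_targets):
    # reg_targets: (n, 4) distances (left, top, right, bottom) to the box sides
    left_right = reg_targets[:, [0, 2]]
    top_bottom = reg_targets[:, [1, 3]]
    ctrness = (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * \
              (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
    return torch.sqrt(ctrness)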
@@ -570,42 +584,52 @@
) / num_pos_avg

# for CondInst
N, C, h, w = self.masks.shape
grid_x = torch.arange(w).view(1,-1).float().repeat(h,1).cuda() / (w-1) * 2 - 1
grid_y = torch.arange(h).view(-1,1).float().repeat(1,w).cuda() / (h-1) * 2 - 1
x_map = grid_x.view(1, 1, h, w).repeat(N, 1, 1, 1)
y_map = grid_y.view(1, 1, h, w).repeat(N, 1, 1, 1)
masks_feat = torch.cat((self.masks, x_map, y_map), dim=1)
batch_ins = pos_inds.shape[0]
N, C, h, w = self.masks.shape
center_x=torch.clamp(locations[:,0],min=0,max=w-1).long()
center_y=torch.clamp(locations[:,1],min=0,max=h-1).long()
x_range = torch.linspace(-1, 1, w, device=self.masks.device)
y_range = torch.linspace(-1, 1, h, device=self.masks.device)
y, x = torch.meshgrid(y_range, x_range)
x = x.unsqueeze(0).unsqueeze(0)
y = y.unsqueeze(0).unsqueeze(0)
grid = torch.cat([x,y],1)
offset_x = x_range[center_x].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
offset_y = y_range[center_y].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
offset_xy = torch.cat([offset_x,offset_y],1)
coords_feat = grid-offset_xy

masks_feat = self.masks
r_h = int(h * self.strides[0])
r_w = int(w * self.strides[0])
targets_masks = [target_im.gt_masks.tensor for target_im in self.gt_instances]
masks_t = self.prepare_masks(h, w, r_h, r_w, targets_masks)
mask_loss = masks_feat[0].new_tensor(0.0)
batch_ins = im_idxes.shape[0]
# for each image
for i in range(N):
inds = (im_idxes==i).nonzero().flatten()
ins_num = inds.shape[0]
if ins_num > 0:
controllers = controllers_pred[inds]
coord_feat=coords_feat[inds]
mask_feat = masks_feat[None, i]
weights1 = controllers[:, :80].reshape(-1,8,10).reshape(-1,10).unsqueeze(-1).unsqueeze(-1)
bias1 = controllers[:, 80:88].flatten()
weights2 = controllers[:, 88:152].reshape(-1,8,8).reshape(-1,8).unsqueeze(-1).unsqueeze(-1)
bias2 = controllers[:, 152:160].flatten()
weights3 = controllers[:, 160:168].unsqueeze(-1).unsqueeze(-1)
bias3 = controllers[:,168:169].flatten()
conv1 = F.conv2d(mask_feat,weights1,bias1).relu()
conv2 = F.conv2d(conv1, weights2, bias2, groups = ins_num).relu()
#masks_per_image = F.conv2d(conv2, weights3, bias3, groups = ins_num)[0].sigmoid()
masks_per_image = F.conv2d(conv2, weights3, bias3, groups = ins_num)
masks_per_image = aligned_bilinear(masks_per_image, self.strides[0])[0].sigmoid()
mask_feat = torch.cat([mask_feat]*ins_num,dim=0)
comb_feat = torch.cat((mask_feat, coord_feat),dim=1).view(1,-1,h,w)
weight1, bias1, weight2, bias2, weight3, bias3 = torch.split(controllers, [80, 8, 64, 8, 8, 1], dim=1)
bias1, bias2, bias3 = bias1.flatten(), bias2.flatten(), bias3.flatten()
weight1 = weight1.reshape(-1,8,10).reshape(-1,10).unsqueeze(-1).unsqueeze(-1)
weight2 = weight2.reshape(-1,8,8).reshape(-1,8).unsqueeze(-1).unsqueeze(-1)
weight3 = weight3.unsqueeze(-1).unsqueeze(-1)
conv1 = F.conv2d(comb_feat, weight1, bias1, groups = ins_num).relu()
conv2 = F.conv2d(conv1, weight2, bias2, groups = ins_num).relu()
masks_per_image = F.conv2d(conv2, weight3, bias3, groups = ins_num)
masks_per_image = aligned_bilinear(masks_per_image, self.strides[0])[0].sigmoid()

for j in range(ins_num):
ind = inds[j]
mask_gt = masks_t[i][matched_idxes[ind]].float()
mask_pred = masks_per_image[j]
mask_loss += self.dice_loss(mask_pred, mask_gt)

if batch_ins > 0:
mask_loss = mask_loss / batch_ins
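
`self.dice_loss` is defined outside this hunk; CondInst-style implementations commonly use the Dice loss below, given here as an assumed sketch rather than the repository's exact definition:

import torch

def dice_loss_sketch(pred, target, eps=1e-5):
    # pred, target: (h, w); pred is already sigmoid-activated
    pred, target = pred.flatten(), target.flatten()
    intersection = (pred * target).sum()
    union = (pred * pred).sum() + (target * target).sum()
    return 1.0 - (2.0 * intersection + eps) / (union + eps)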
