diff --git a/configs/mask_rtdetr/_base_/mask_rtdetr_r50vd.yml b/configs/mask_rtdetr/_base_/mask_rtdetr_r50vd.yml index 063f534b9d..37fc0b28ba 100644 --- a/configs/mask_rtdetr/_base_/mask_rtdetr_r50vd.yml +++ b/configs/mask_rtdetr/_base_/mask_rtdetr_r50vd.yml @@ -67,6 +67,7 @@ MaskDINOHead: loss_coeff: {class: 4, bbox: 5, giou: 2, mask: 5, dice: 5} aux_loss: True use_vfl: True + vfl_iou_type: 'mask' matcher: name: HungarianMatcher matcher_coeff: {class: 4, bbox: 5, giou: 2, mask: 5, dice: 5} diff --git a/ppdet/modeling/losses/detr_loss.py b/ppdet/modeling/losses/detr_loss.py index b3fa6b92bd..a322c1e84e 100644 --- a/ppdet/modeling/losses/detr_loss.py +++ b/ppdet/modeling/losses/detr_loss.py @@ -46,6 +46,7 @@ def __init__(self, aux_loss=True, use_focal_loss=False, use_vfl=False, + vfl_iou_type='bbox', use_uni_match=False, uni_match_ind=0): r""" @@ -65,6 +66,7 @@ def __init__(self, self.aux_loss = aux_loss self.use_focal_loss = use_focal_loss self.use_vfl = use_vfl + self.vfl_iou_type = vfl_iou_type self.use_uni_match = use_uni_match self.uni_match_ind = uni_match_ind @@ -329,11 +331,41 @@ def _get_prediction_loss(self, _, target_score = self._get_src_target_assign( logits[-1].detach(), gt_score, match_indices) elif sum(len(a) for a in gt_bbox) > 0: - src_bbox, target_bbox = self._get_src_target_assign( - boxes.detach(), gt_bbox, match_indices) - iou_score = bbox_iou( - bbox_cxcywh_to_xyxy(src_bbox).split(4, -1), - bbox_cxcywh_to_xyxy(target_bbox).split(4, -1)) + if self.vfl_iou_type == 'bbox': + src_bbox, target_bbox = self._get_src_target_assign( + boxes.detach(), gt_bbox, match_indices) + iou_score = bbox_iou( + bbox_cxcywh_to_xyxy(src_bbox).split(4, -1), + bbox_cxcywh_to_xyxy(target_bbox).split(4, -1)) + elif self.vfl_iou_type == 'mask': + assert (masks is not None and gt_mask is not None, + 'Make sure the input has `mask` and `gt_mask`') + assert sum(len(a) for a in gt_mask) > 0 + src_mask, target_mask = self._get_src_target_assign( + masks.detach(), gt_mask, match_indices) + src_mask = F.interpolate( + src_mask.unsqueeze(0), + scale_factor=2, + mode='bilinear', + align_corners=False).squeeze(0) + target_mask = F.interpolate( + target_mask.unsqueeze(0), + size=src_mask.shape[-2:], + mode='bilinear', + align_corners=False).squeeze(0) + src_mask = src_mask.flatten(1) + src_mask = F.sigmoid(src_mask) + src_mask = paddle.where( + src_mask > 0.5, 1., 0.).astype(masks.dtype) + target_mask = target_mask.flatten(1) + target_mask = paddle.where( + target_mask > 0.5, 1., 0.).astype(masks.dtype) + inter = (src_mask * target_mask).sum(1) + union = src_mask.sum(1) + target_mask.sum(1) - inter + iou_score = (inter + 1e-2) / (union + 1e-2) + iou_score = iou_score.unsqueeze(-1) + else: + iou_score = None else: iou_score = None else: @@ -502,11 +534,12 @@ def __init__(self, aux_loss=True, use_focal_loss=False, use_vfl=False, + vfl_iou_type='bbox', num_sample_points=12544, oversample_ratio=3.0, important_sample_ratio=0.75): super(MaskDINOLoss, self).__init__(num_classes, matcher, loss_coeff, - aux_loss, use_focal_loss, use_vfl) + aux_loss, use_focal_loss, use_vfl, vfl_iou_type) assert oversample_ratio >= 1 assert important_sample_ratio <= 1 and important_sample_ratio >= 0 diff --git a/ppdet/modeling/transformers/hybrid_encoder.py b/ppdet/modeling/transformers/hybrid_encoder.py index 6f9333cd9c..fcbca6fa5c 100644 --- a/ppdet/modeling/transformers/hybrid_encoder.py +++ b/ppdet/modeling/transformers/hybrid_encoder.py @@ -306,8 +306,8 @@ def __init__(self, in_channels=[256, 256, 256], fpn_strides=[32, 16, 8], feat_channels=256, - dropout_ratio=0.0, - mask_dim=8, + dropout_ratio=0.1, + mask_dim=32, align_corners=False, act='swish'): super(MaskFeatFPN, self).__init__() @@ -332,7 +332,8 @@ def __init__(self, in_c = in_channels[i] if k == 0 else feat_channels scale_head.append( nn.Sequential( - BaseConv(in_c, feat_channels, 3, 1, act=act)) + BaseConv( + in_c, feat_channels, 3, 1, act=act)), ) if fpn_strides[i] != fpn_strides[0]: scale_head.append( @@ -343,7 +344,7 @@ def __init__(self, self.scale_heads.append(nn.Sequential(*scale_head)) - self.output_proj = nn.Conv2D(feat_channels, mask_dim, 1) + self.output_proj = nn.Conv2D(feat_channels + 2, mask_dim, 1) def forward(self, inputs): x = [inputs[i] for i in self.reorder_index] @@ -358,6 +359,18 @@ def forward(self, inputs): if self.dropout_ratio > 0: output = self.dropout(output) + + bs, _, h, w = output.shape + x_range = paddle.linspace(-1, 1, w, dtype='float32') + y_range = paddle.linspace(-1, 1, h, dtype='float32') + y, x = paddle.meshgrid([y_range, x_range]) + x = paddle.unsqueeze(x, [0, 1]) + y = paddle.unsqueeze(y, [0, 1]) + y = paddle.expand(y, shape=[bs, 1, -1, -1]) + x = paddle.expand(x, shape=[bs, 1, -1, -1]) + coord_feat = paddle.concat([x, y], axis=1) + + output = paddle.concat([coord_feat, output], axis=1) output = self.output_proj(output) return output diff --git a/ppdet/modeling/transformers/mask_rtdetr_transformer.py b/ppdet/modeling/transformers/mask_rtdetr_transformer.py index 06168a8711..338a8a5486 100644 --- a/ppdet/modeling/transformers/mask_rtdetr_transformer.py +++ b/ppdet/modeling/transformers/mask_rtdetr_transformer.py @@ -93,7 +93,7 @@ def forward(self, if self.training: logits_, masks_ = _get_pred_class_and_mask( - dec_norm(output), mask_feat, dec_norm, + output, mask_feat, dec_norm, score_head, mask_query_head) dec_out_logits.append(logits_) dec_out_masks.append(masks_) @@ -105,7 +105,7 @@ def forward(self, inverse_sigmoid(ref_points))) elif i == self.eval_idx: logits_, masks_ = _get_pred_class_and_mask( - dec_norm(output), mask_feat, dec_norm, + output, mask_feat, dec_norm, score_head, mask_query_head) dec_out_logits.append(logits_) dec_out_masks.append(masks_)