From c1b86778aacee2d17bf03dfd8d6b4c8c3c4481cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Haian=20Huang=28=E6=B7=B1=E5=BA=A6=E7=9C=B8=29?= <1286304229@qq.com> Date: Tue, 22 Aug 2023 11:45:50 +0800 Subject: [PATCH] Support CO-DETR (#10740) Co-authored-by: huanghaian --- README.md | 5 + README_zh-CN.md | 5 + mmdet/models/dense_heads/detr_head.py | 26 +- mmdet/models/dense_heads/dino_head.py | 29 +- projects/CO-DETR/README.md | 32 + projects/CO-DETR/codetr/__init__.py | 13 + projects/CO-DETR/codetr/co_atss_head.py | 153 ++ projects/CO-DETR/codetr/co_dino_head.py | 677 ++++++++ projects/CO-DETR/codetr/co_roi_head.py | 108 ++ projects/CO-DETR/codetr/codetr.py | 320 ++++ projects/CO-DETR/codetr/transformer.py | 1376 +++++++++++++++++ .../codino/co_dino_5scale_r50_8xb2_1x_coco.py | 68 + .../co_dino_5scale_r50_lsj_8xb2_1x_coco.py | 359 +++++ .../co_dino_5scale_r50_lsj_8xb2_3x_coco.py | 4 + ...dino_5scale_swin_l_16xb1_16e_o365tococo.py | 115 ++ .../co_dino_5scale_swin_l_16xb1_1x_coco.py | 31 + .../co_dino_5scale_swin_l_16xb1_3x_coco.py | 6 + ...co_dino_5scale_swin_l_lsj_16xb1_1x_coco.py | 72 + ...co_dino_5scale_swin_l_lsj_16xb1_3x_coco.py | 6 + tools/model_converters/glip_to_mmdet.py | 3 +- tools/model_converters/swinv1_to_mmdet.py | 86 ++ 21 files changed, 3486 insertions(+), 8 deletions(-) create mode 100644 projects/CO-DETR/README.md create mode 100644 projects/CO-DETR/codetr/__init__.py create mode 100644 projects/CO-DETR/codetr/co_atss_head.py create mode 100644 projects/CO-DETR/codetr/co_dino_head.py create mode 100644 projects/CO-DETR/codetr/co_roi_head.py create mode 100644 projects/CO-DETR/codetr/codetr.py create mode 100644 projects/CO-DETR/codetr/transformer.py create mode 100644 projects/CO-DETR/configs/codino/co_dino_5scale_r50_8xb2_1x_coco.py create mode 100644 projects/CO-DETR/configs/codino/co_dino_5scale_r50_lsj_8xb2_1x_coco.py create mode 100644 projects/CO-DETR/configs/codino/co_dino_5scale_r50_lsj_8xb2_3x_coco.py create mode 100644 projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_16e_o365tococo.py create mode 100644 projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_1x_coco.py create mode 100644 projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_3x_coco.py create mode 100644 projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_lsj_16xb1_1x_coco.py create mode 100644 projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_lsj_16xb1_3x_coco.py create mode 100644 tools/model_converters/swinv1_to_mmdet.py diff --git a/README.md b/README.md index 89748a970d0..0b5a1d16c39 100644 --- a/README.md +++ b/README.md @@ -236,9 +236,12 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
  • DAB-DETR (ICLR'2022)
  • DINO (ICLR'2023)
  • GLIP (CVPR'2022)
  • +
  • DDQ (CVPR'2023)
  • DiffusionDet (ArXiv'2023)
  • EfficientDet (CVPR'2020)
  • +
  • ViTDet (ECCV'2022)
  • Detic (ECCV'2022)
  • +
  • CO-DETR (ICCV'2023)
  • @@ -260,6 +263,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
  • SparseInst (CVPR'2022)
  • RTMDet (ArXiv'2022)
  • BoxInst (CVPR'2021)
  • +
      • ConvNeXt-V2 (ArXiv'2023)
  • 
  • @@ -267,6 +271,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
  • Panoptic FPN (CVPR'2019)
  • MaskFormer (NeurIPS'2021)
  • Mask2Former (ArXiv'2021)
  • +
  • XDecoder (CVPR'2023)
  • diff --git a/README_zh-CN.md b/README_zh-CN.md index 7f2713dec75..cd3813ade8d 100644 --- a/README_zh-CN.md +++ b/README_zh-CN.md @@ -237,9 +237,12 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope
  • DAB-DETR (ICLR'2022)
  • DINO (ICLR'2023)
  • GLIP (CVPR'2022)
  • +
  • DDQ (CVPR'2023)
  • DiffusionDet (ArXiv'2023)
  • EfficientDet (CVPR'2020)
  • +
  • ViTDet (ECCV'2022)
  • Detic (ECCV'2022)
  • +
  • CO-DETR (ICCV'2023)
  • @@ -261,6 +264,7 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope
  • SparseInst (CVPR'2022)
  • RTMDet (ArXiv'2022)
  • BoxInst (CVPR'2021)
  • +
   • ConvNeXt-V2 (ArXiv'2023)
  • 
  • @@ -268,6 +272,7 @@ MMDetection 是一个基于 PyTorch 的目标检测开源工具箱。它是 [Ope
  • Panoptic FPN (CVPR'2019)
  • MaskFormer (NeurIPS'2021)
  • Mask2Former (ArXiv'2021)
  • +
  • XDecoder (CVPR'2023)
  • diff --git a/mmdet/models/dense_heads/detr_head.py b/mmdet/models/dense_heads/detr_head.py index 42a94d1ae9c..61545ce364d 100644 --- a/mmdet/models/dense_heads/detr_head.py +++ b/mmdet/models/dense_heads/detr_head.py @@ -12,9 +12,11 @@ from mmdet.registry import MODELS, TASK_UTILS from mmdet.structures import SampleList -from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh +from mmdet.structures.bbox import (bbox_cxcywh_to_xyxy, bbox_overlaps, + bbox_xyxy_to_cxcywh) from mmdet.utils import (ConfigType, InstanceList, OptInstanceList, OptMultiConfig, reduce_mean) +from ..losses import QualityFocalLoss from ..utils import multi_apply @@ -290,8 +292,26 @@ def loss_by_feat_single(self, cls_scores: Tensor, bbox_preds: Tensor, cls_scores.new_tensor([cls_avg_factor])) cls_avg_factor = max(cls_avg_factor, 1) - loss_cls = self.loss_cls( - cls_scores, labels, label_weights, avg_factor=cls_avg_factor) + if isinstance(self.loss_cls, QualityFocalLoss): + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) + & (labels < bg_class_ind)).nonzero().squeeze(1) + scores = label_weights.new_zeros(labels.shape) + pos_bbox_targets = bbox_targets[pos_inds] + pos_decode_bbox_targets = bbox_cxcywh_to_xyxy(pos_bbox_targets) + pos_bbox_pred = bbox_preds.reshape(-1, 4)[pos_inds] + pos_decode_bbox_pred = bbox_cxcywh_to_xyxy(pos_bbox_pred) + scores[pos_inds] = bbox_overlaps( + pos_decode_bbox_pred.detach(), + pos_decode_bbox_targets, + is_aligned=True) + loss_cls = self.loss_cls( + cls_scores, (labels, scores), + label_weights, + avg_factor=cls_avg_factor) + else: + loss_cls = self.loss_cls( + cls_scores, labels, label_weights, avg_factor=cls_avg_factor) # Compute the average number of gt boxes across all gpus, for # normalization purposes diff --git a/mmdet/models/dense_heads/dino_head.py b/mmdet/models/dense_heads/dino_head.py index 889ff381100..54f46d1474f 100644 --- a/mmdet/models/dense_heads/dino_head.py +++ b/mmdet/models/dense_heads/dino_head.py @@ -7,8 +7,10 @@ from mmdet.registry import MODELS from mmdet.structures import SampleList -from mmdet.structures.bbox import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh +from mmdet.structures.bbox import (bbox_cxcywh_to_xyxy, bbox_overlaps, + bbox_xyxy_to_cxcywh) from mmdet.utils import InstanceList, OptInstanceList, reduce_mean +from ..losses import QualityFocalLoss from ..utils import multi_apply from .deformable_detr_head import DeformableDETRHead @@ -248,8 +250,29 @@ def _loss_dn_single(self, dn_cls_scores: Tensor, dn_bbox_preds: Tensor, cls_avg_factor = max(cls_avg_factor, 1) if len(cls_scores) > 0: - loss_cls = self.loss_cls( - cls_scores, labels, label_weights, avg_factor=cls_avg_factor) + if isinstance(self.loss_cls, QualityFocalLoss): + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) + & (labels < bg_class_ind)).nonzero().squeeze(1) + scores = label_weights.new_zeros(labels.shape) + pos_bbox_targets = bbox_targets[pos_inds] + pos_decode_bbox_targets = bbox_cxcywh_to_xyxy(pos_bbox_targets) + pos_bbox_pred = dn_bbox_preds.reshape(-1, 4)[pos_inds] + pos_decode_bbox_pred = bbox_cxcywh_to_xyxy(pos_bbox_pred) + scores[pos_inds] = bbox_overlaps( + pos_decode_bbox_pred.detach(), + pos_decode_bbox_targets, + is_aligned=True) + loss_cls = self.loss_cls( + cls_scores, (labels, scores), + weight=label_weights, + avg_factor=cls_avg_factor) + else: + loss_cls = self.loss_cls( + cls_scores, + labels, + label_weights, + avg_factor=cls_avg_factor) else: loss_cls = torch.zeros( 1, dtype=cls_scores.dtype, device=cls_scores.device) diff 
 --git a/projects/CO-DETR/README.md b/projects/CO-DETR/README.md new file mode 100644 index 00000000000..787592ade50 --- /dev/null +++ b/projects/CO-DETR/README.md @@ -0,0 +1,32 @@ +# CO-DETR + +> [DETRs with Collaborative Hybrid Assignments Training](https://arxiv.org/abs/2211.12860) + + + +## Abstract + +In this paper, we observe that too few queries are assigned as positive samples in DETR's one-to-one set matching, which leads to sparse supervision on the encoder's output, considerably hurts the discriminative feature learning of the encoder, and vice versa for attention learning in the decoder. To alleviate this, we present a novel collaborative hybrid assignments training scheme, namely Co-DETR, to learn more efficient and effective DETR-based detectors from versatile label assignment manners. This new training scheme easily enhances the encoder's learning ability in end-to-end detectors by training multiple parallel auxiliary heads supervised by one-to-many label assignments such as ATSS and Faster R-CNN. In addition, we construct extra customized positive queries by extracting the positive coordinates from these auxiliary heads, improving the training efficiency of positive samples in the decoder. At inference time, these auxiliary heads are discarded, so our method introduces no additional parameters or computational cost to the original detector while requiring no hand-crafted non-maximum suppression (NMS). We conduct extensive experiments to evaluate the effectiveness of the proposed approach on DETR variants, including DAB-DETR, Deformable-DETR, and DINO-Deformable-DETR. The state-of-the-art DINO-Deformable-DETR with Swin-L can be improved from 58.5% to 59.5% AP on COCO val. Surprisingly, with a ViT-L backbone, we achieve 66.0% AP on COCO test-dev and 67.9% AP on LVIS val, outperforming previous methods by clear margins with much smaller models. + 
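 +The scheme above maps onto `CoDETR.loss` in `projects/CO-DETR/codetr/codetr.py`: the DETR query head supplies the one-to-one loss together with its encoder features, the auxiliary heads (RPN/RoI and one-stage) add one-to-many losses on those shared features, and their positive coordinates come back as extra positive queries. A minimal sketch of this control flow is shown below; the head objects are duck-typed stand-ins for illustration, not the shipped classes: + +```python +# Illustrative sketch only: mirrors the control flow of CoDETR.loss in +# this patch, with stand-in head objects (simplified, e.g. the RoI head +# proposals are omitted). +def co_detr_loss(feats, query_head, aux_heads, batch_data_samples): + losses = {} + # 1) One-to-one supervision: the query head returns its losses and + # the encoder feature maps reused by the auxiliary heads. + detr_losses, enc_feats = query_head.loss(feats, batch_data_samples) + losses.update(detr_losses) + # 2) One-to-many supervision: each auxiliary head (e.g. ATSS, + # Faster R-CNN) is trained on the shared encoder features and + # exposes its positive coordinates via the 'pos_coords' key. + positive_coords = [] + for i, head in enumerate(aux_heads): + head_losses = head.loss(enc_feats, batch_data_samples) + positive_coords.append(head_losses.pop('pos_coords')) + losses.update({f'{k}{i}': v for k, v in head_losses.items()}) + # 3) Customized positive queries: each set of positive coordinates + # supervises extra decoder queries through the query head. + for i, pos_coords in enumerate(positive_coords): + aux_losses = query_head.loss_aux(enc_feats, pos_coords, i, + batch_data_samples) + losses.update({f'{k}_aux{i}': v for k, v in aux_losses.items()}) + return losses +``` 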
    + +
 + +## Results and Models + +| Model | Backbone | Epochs | Aug | Dataset | box AP | Config | Download | +| :-------: | :------: | :----: | :--: | :---------------------------: | :----: | :--------------------------------------------------------------------: | :------: | +| Co-DINO | R50 | 12 | LSJ | COCO | 52.0 | [config](configs/codino/co_dino_5scale_r50_lsj_8xb2_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_r50_lsj_8xb2_1x_coco/co_dino_5scale_r50_lsj_8xb2_1x_coco-69a72d67.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_r50_lsj_8xb2_1x_coco/co_dino_5scale_r50_lsj_8xb2_1x_coco_20230818_150457.json) | +| Co-DINO\* | R50 | 12 | DETR | COCO | 52.1 | [config](configs/codino/co_dino_5scale_r50_8xb2_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_r50_1x_coco-7481f903.pth) | +| Co-DINO\* | R50 | 36 | LSJ | COCO | 54.8 | [config](configs/codino/co_dino_5scale_r50_lsj_8xb2_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_lsj_r50_3x_coco-fe5a6829.pth) | +| Co-DINO\* | Swin-L | 12 | DETR | COCO | 58.9 | [config](configs/codino/co_dino_5scale_swin_l_16xb1_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_swin_large_1x_coco-27c13da4.pth) | +| Co-DINO\* | Swin-L | 12 | LSJ | COCO | 59.3 | [config](configs/codino/co_dino_5scale_swin_l_lsj_16xb1_1x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_lsj_swin_large_1x_coco-3af73af2.pth) | +| Co-DINO\* | Swin-L | 36 | DETR | COCO | 60.0 | [config](configs/codino/co_dino_5scale_swin_l_16xb1_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_swin_large_3x_coco-d7a6d8af.pth) | +| Co-DINO\* | Swin-L | 36 | LSJ | COCO | 60.7 | [config](configs/codino/co_dino_5scale_swin_l_lsj_16xb1_3x_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_lsj_swin_large_1x_coco-3af73af2.pth) | +| Co-DINO\* | Swin-L | 16 | DETR | Objects365 pre-trained + COCO | 64.1 | [config](configs/codino/co_dino_5scale_swin_l_16xb1_16e_o365tococo.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_swin_large_16e_o365tococo-614254c9.pth) | + +Note + +- Models labeled \* were not trained by us; they come from the official [CO-DETR](https://github.com/Sense-X/Co-DETR) repository. +- We find that the performance is unstable and may fluctuate by about 0.3 mAP. +- To save GPU memory by enabling activation checkpointing, first install fairscale with `pip install fairscale`. diff --git a/projects/CO-DETR/codetr/__init__.py b/projects/CO-DETR/codetr/__init__.py new file mode 100644 index 00000000000..2ca4c02d9f7 --- /dev/null +++ b/projects/CO-DETR/codetr/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from .co_atss_head import CoATSSHead +from .co_dino_head import CoDINOHead +from .co_roi_head import CoStandardRoIHead +from .codetr import CoDETR +from .transformer import (CoDinoTransformer, DetrTransformerDecoderLayer, + DetrTransformerEncoder, DinoTransformerDecoder) + +__all__ = [ + 'CoDETR', 'CoDinoTransformer', 'DinoTransformerDecoder', 'CoDINOHead', + 'CoATSSHead', 'CoStandardRoIHead', 'DetrTransformerEncoder', + 'DetrTransformerDecoderLayer' +] diff --git a/projects/CO-DETR/codetr/co_atss_head.py b/projects/CO-DETR/codetr/co_atss_head.py new file mode 100644 index 00000000000..c6ae0180da7 --- /dev/null +++ b/projects/CO-DETR/codetr/co_atss_head.py @@ -0,0 +1,153 @@ +from typing import List + +import torch +from torch import Tensor + +from mmdet.models.dense_heads import ATSSHead +from mmdet.models.utils import images_to_levels, multi_apply +from mmdet.registry import MODELS +from mmdet.utils import InstanceList, OptInstanceList, reduce_mean + + +@MODELS.register_module() +class CoATSSHead(ATSSHead): + + def loss_by_feat( + self, + cls_scores: List[Tensor], + bbox_preds: List[Tensor], + centernesses: List[Tensor], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None) -> dict: + """Calculate the loss based on the features extracted by the detection + head. + + Args: + cls_scores (list[Tensor]): Box scores for each scale level + Has shape (N, num_anchors * num_classes, H, W) + bbox_preds (list[Tensor]): Box energies / deltas for each scale + level with shape (N, num_anchors * 4, H, W) + centernesses (list[Tensor]): Centerness for each scale + level with shape (N, num_anchors * 1, H, W) + batch_gt_instances (list[:obj:`InstanceData`]): Batch of + gt_instance. It usually includes ``bboxes`` and ``labels`` + attributes. + batch_img_metas (list[dict]): Meta information of each image, e.g., + image size, scaling factor, etc. + batch_gt_instances_ignore (list[:obj:`InstanceData`], Optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. + Defaults to None. + + Returns: + dict[str, Tensor]: A dictionary of loss components. 
 + """ + featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores] + assert len(featmap_sizes) == self.prior_generator.num_levels + + device = cls_scores[0].device + anchor_list, valid_flag_list = self.get_anchors( + featmap_sizes, batch_img_metas, device=device) + + cls_reg_targets = self.get_targets( + anchor_list, + valid_flag_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore=batch_gt_instances_ignore) + + (anchor_list, labels_list, label_weights_list, bbox_targets_list, + bbox_weights_list, avg_factor, ori_anchors, ori_labels, + ori_bbox_targets) = cls_reg_targets + + avg_factor = reduce_mean( + torch.tensor(avg_factor, dtype=torch.float, device=device)).item() + + losses_cls, losses_bbox, loss_centerness, \ + bbox_avg_factor = multi_apply( + self.loss_by_feat_single, + anchor_list, + cls_scores, + bbox_preds, + centernesses, + labels_list, + label_weights_list, + bbox_targets_list, + avg_factor=avg_factor) + + bbox_avg_factor = sum(bbox_avg_factor) + bbox_avg_factor = reduce_mean(bbox_avg_factor).clamp_(min=1).item() + losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox)) + + # diff with ATSSHead: also expose the original anchors, labels and + # targets so Co-DETR can build auxiliary positive queries from them. + pos_coords = (ori_anchors, ori_labels, ori_bbox_targets, 'atss') + return dict( + loss_cls=losses_cls, + loss_bbox=losses_bbox, + loss_centerness=loss_centerness, + pos_coords=pos_coords) + + def get_targets(self, + anchor_list: List[List[Tensor]], + valid_flag_list: List[List[Tensor]], + batch_gt_instances: InstanceList, + batch_img_metas: List[dict], + batch_gt_instances_ignore: OptInstanceList = None, + unmap_outputs: bool = True) -> tuple: + """Get targets for ATSS head. + + This method is almost the same as `AnchorHead.get_targets()`. Besides + returning the targets as the parent method does, it also returns the + anchors as the first element of the returned tuple. + """ + num_imgs = len(batch_img_metas) + assert len(anchor_list) == len(valid_flag_list) == num_imgs + + # anchor number of multi levels + num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]] + num_level_anchors_list = [num_level_anchors] * num_imgs + + # concat all level anchors and flags to a single tensor + for i in range(num_imgs): + assert len(anchor_list[i]) == len(valid_flag_list[i]) + anchor_list[i] = torch.cat(anchor_list[i]) + valid_flag_list[i] = torch.cat(valid_flag_list[i]) + + # compute targets for each image + if batch_gt_instances_ignore is None: + batch_gt_instances_ignore = [None] * num_imgs + (all_anchors, all_labels, all_label_weights, all_bbox_targets, + all_bbox_weights, pos_inds_list, neg_inds_list, + sampling_results_list) = multi_apply( + self._get_targets_single, + anchor_list, + valid_flag_list, + num_level_anchors_list, + batch_gt_instances, + batch_img_metas, + batch_gt_instances_ignore, + unmap_outputs=unmap_outputs) + # Get `avg_factor` of all images, which is calculated in + # `SamplingResult`. When using a sampling method, `avg_factor` is + # usually the sum of positive and negative priors. When using + # `PseudoSampler`, `avg_factor` is usually equal to the number of + # positive priors. + avg_factor = sum( + [results.avg_factor for results in sampling_results_list]) + # split targets to a list w.r.t. 
multiple levels + anchors_list = images_to_levels(all_anchors, num_level_anchors) + labels_list = images_to_levels(all_labels, num_level_anchors) + label_weights_list = images_to_levels(all_label_weights, + num_level_anchors) + bbox_targets_list = images_to_levels(all_bbox_targets, + num_level_anchors) + bbox_weights_list = images_to_levels(all_bbox_weights, + num_level_anchors) + + # diff + ori_anchors = all_anchors + ori_labels = all_labels + ori_bbox_targets = all_bbox_targets + return (anchors_list, labels_list, label_weights_list, + bbox_targets_list, bbox_weights_list, avg_factor, ori_anchors, + ori_labels, ori_bbox_targets) diff --git a/projects/CO-DETR/codetr/co_dino_head.py b/projects/CO-DETR/codetr/co_dino_head.py new file mode 100644 index 00000000000..192acf97d86 --- /dev/null +++ b/projects/CO-DETR/codetr/co_dino_head.py @@ -0,0 +1,677 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +from typing import List + +import torch +import torch.nn as nn +import torch.nn.functional as F +from mmcv.cnn import Linear +from mmcv.ops import batched_nms +from mmengine.structures import InstanceData +from torch import Tensor + +from mmdet.models import DINOHead +from mmdet.models.layers import CdnQueryGenerator +from mmdet.models.layers.transformer import inverse_sigmoid +from mmdet.models.utils import multi_apply +from mmdet.registry import MODELS +from mmdet.structures import SampleList +from mmdet.structures.bbox import (bbox_cxcywh_to_xyxy, bbox_overlaps, + bbox_xyxy_to_cxcywh) +from mmdet.utils import InstanceList, reduce_mean + + +@MODELS.register_module() +class CoDINOHead(DINOHead): + + def __init__(self, + *args, + num_query=900, + transformer=None, + in_channels=2048, + max_pos_coords=300, + dn_cfg=None, + use_zero_padding=False, + positional_encoding=dict( + type='SinePositionalEncoding', + num_feats=128, + normalize=True), + **kwargs): + self.with_box_refine = True + self.mixed_selection = True + self.in_channels = in_channels + self.max_pos_coords = max_pos_coords + self.positional_encoding = positional_encoding + self.num_query = num_query + self.use_zero_padding = use_zero_padding + + if 'two_stage_num_proposals' in transformer: + assert transformer['two_stage_num_proposals'] == num_query, \ + 'two_stage_num_proposals must be equal to num_query for DINO' + else: + transformer['two_stage_num_proposals'] = num_query + transformer['as_two_stage'] = True + if self.mixed_selection: + transformer['mixed_selection'] = self.mixed_selection + self.transformer = transformer + self.act_cfg = transformer.get('act_cfg', + dict(type='ReLU', inplace=True)) + + super().__init__(*args, **kwargs) + + self.activate = MODELS.build(self.act_cfg) + self.positional_encoding = MODELS.build(self.positional_encoding) + self.init_denoising(dn_cfg) + + def _init_layers(self): + self.transformer = MODELS.build(self.transformer) + self.embed_dims = self.transformer.embed_dims + assert hasattr(self.positional_encoding, 'num_feats') + num_feats = self.positional_encoding.num_feats + assert num_feats * 2 == self.embed_dims, 'embed_dims should' \ + f' be exactly 2 times of num_feats. Found {self.embed_dims}' \ + f' and {num_feats}.' 
+ """Initialize classification branch and regression branch of head.""" + fc_cls = Linear(self.embed_dims, self.cls_out_channels) + reg_branch = [] + for _ in range(self.num_reg_fcs): + reg_branch.append(Linear(self.embed_dims, self.embed_dims)) + reg_branch.append(nn.ReLU()) + reg_branch.append(Linear(self.embed_dims, 4)) + reg_branch = nn.Sequential(*reg_branch) + + def _get_clones(module, N): + return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) + + # last reg_branch is used to generate proposal from + # encode feature map when as_two_stage is True. + num_pred = (self.transformer.decoder.num_layers + 1) if \ + self.as_two_stage else self.transformer.decoder.num_layers + + self.cls_branches = _get_clones(fc_cls, num_pred) + self.reg_branches = _get_clones(reg_branch, num_pred) + + self.downsample = nn.Sequential( + nn.Conv2d( + self.embed_dims, + self.embed_dims, + kernel_size=3, + stride=2, + padding=1), nn.GroupNorm(32, self.embed_dims)) + + def init_denoising(self, dn_cfg): + if dn_cfg is not None: + dn_cfg['num_classes'] = self.num_classes + dn_cfg['num_matching_queries'] = self.num_query + dn_cfg['embed_dims'] = self.embed_dims + self.dn_generator = CdnQueryGenerator(**dn_cfg) + + def forward(self, + mlvl_feats, + img_metas, + dn_label_query=None, + dn_bbox_query=None, + attn_mask=None): + batch_size = mlvl_feats[0].size(0) + input_img_h, input_img_w = img_metas[0]['batch_input_shape'] + img_masks = mlvl_feats[0].new_ones( + (batch_size, input_img_h, input_img_w)) + for img_id in range(batch_size): + img_h, img_w = img_metas[img_id]['img_shape'] + img_masks[img_id, :img_h, :img_w] = 0 + + mlvl_masks = [] + mlvl_positional_encodings = [] + for feat in mlvl_feats: + mlvl_masks.append( + F.interpolate(img_masks[None], + size=feat.shape[-2:]).to(torch.bool).squeeze(0)) + mlvl_positional_encodings.append( + self.positional_encoding(mlvl_masks[-1])) + + query_embeds = None + hs, inter_references, topk_score, topk_anchor, enc_outputs = \ + self.transformer( + mlvl_feats, + mlvl_masks, + query_embeds, + mlvl_positional_encodings, + dn_label_query, + dn_bbox_query, + attn_mask, + reg_branches=self.reg_branches if self.with_box_refine else None, # noqa:E501 + cls_branches=self.cls_branches if self.as_two_stage else None # noqa:E501 + ) + outs = [] + num_level = len(mlvl_feats) + start = 0 + for lvl in range(num_level): + bs, c, h, w = mlvl_feats[lvl].shape + end = start + h * w + feat = enc_outputs[start:end].permute(1, 2, 0).contiguous() + start = end + outs.append(feat.reshape(bs, c, h, w)) + outs.append(self.downsample(outs[-1])) + + hs = hs.permute(0, 2, 1, 3) + + if dn_label_query is not None and dn_label_query.size(1) == 0: + # NOTE: If there is no target in the image, the parameters of + # label_embedding won't be used in producing loss, which raises + # RuntimeError when using distributed mode. 
+ hs[0] += self.dn_generator.label_embedding.weight[0, 0] * 0.0 + + outputs_classes = [] + outputs_coords = [] + + for lvl in range(hs.shape[0]): + reference = inter_references[lvl] + reference = inverse_sigmoid(reference, eps=1e-3) + outputs_class = self.cls_branches[lvl](hs[lvl]) + tmp = self.reg_branches[lvl](hs[lvl]) + if reference.shape[-1] == 4: + tmp += reference + else: + assert reference.shape[-1] == 2 + tmp[..., :2] += reference + outputs_coord = tmp.sigmoid() + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + + outputs_classes = torch.stack(outputs_classes) + outputs_coords = torch.stack(outputs_coords) + + return outputs_classes, outputs_coords, topk_score, topk_anchor, outs + + def predict(self, + feats: List[Tensor], + batch_data_samples: SampleList, + rescale: bool = True) -> InstanceList: + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + outs = self.forward(feats, batch_img_metas) + + predictions = self.predict_by_feat( + *outs, batch_img_metas=batch_img_metas, rescale=rescale) + + return predictions + + def predict_by_feat(self, + all_cls_scores, + all_bbox_preds, + enc_cls_scores, + enc_bbox_preds, + enc_outputs, + batch_img_metas, + rescale=True): + + cls_scores = all_cls_scores[-1] + bbox_preds = all_bbox_preds[-1] + + result_list = [] + for img_id in range(len(batch_img_metas)): + cls_score = cls_scores[img_id] + bbox_pred = bbox_preds[img_id] + img_meta = batch_img_metas[img_id] + results = self._predict_by_feat_single(cls_score, bbox_pred, + img_meta, rescale) + result_list.append(results) + return result_list + + def _predict_by_feat_single(self, + cls_score: Tensor, + bbox_pred: Tensor, + img_meta: dict, + rescale: bool = True) -> InstanceData: + """Transform outputs from the last decoder layer into bbox predictions + for each image. + + Args: + cls_score (Tensor): Box score logits from the last decoder layer + for each image. Shape [num_queries, cls_out_channels]. + bbox_pred (Tensor): Sigmoid outputs from the last decoder layer + for each image, with coordinate format (cx, cy, w, h) and + shape [num_queries, 4]. + img_meta (dict): Image meta info. + rescale (bool): If True, return boxes in original image + space. Default True. + + Returns: + :obj:`InstanceData`: Detection results of each image + after the post process. + Each item usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). 
+ """ + assert len(cls_score) == len(bbox_pred) # num_queries + max_per_img = self.test_cfg.get('max_per_img', self.num_query) + score_thr = self.test_cfg.get('score_thr', 0) + with_nms = self.test_cfg.get('nms', None) + + img_shape = img_meta['img_shape'] + # exclude background + if self.loss_cls.use_sigmoid: + cls_score = cls_score.sigmoid() + scores, indexes = cls_score.view(-1).topk(max_per_img) + det_labels = indexes % self.num_classes + bbox_index = indexes // self.num_classes + bbox_pred = bbox_pred[bbox_index] + else: + scores, det_labels = F.softmax(cls_score, dim=-1)[..., :-1].max(-1) + scores, bbox_index = scores.topk(max_per_img) + bbox_pred = bbox_pred[bbox_index] + det_labels = det_labels[bbox_index] + + if score_thr > 0: + valid_mask = scores > score_thr + scores = scores[valid_mask] + bbox_pred = bbox_pred[valid_mask] + det_labels = det_labels[valid_mask] + + det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred) + det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1] + det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0] + det_bboxes[:, 0::2].clamp_(min=0, max=img_shape[1]) + det_bboxes[:, 1::2].clamp_(min=0, max=img_shape[0]) + if rescale: + assert img_meta.get('scale_factor') is not None + det_bboxes /= det_bboxes.new_tensor( + img_meta['scale_factor']).repeat((1, 2)) + + results = InstanceData() + results.bboxes = det_bboxes + results.scores = scores + results.labels = det_labels + + if with_nms and results.bboxes.numel() > 0: + det_bboxes, keep_idxs = batched_nms(results.bboxes, results.scores, + results.labels, + self.test_cfg.nms) + results = results[keep_idxs] + results.scores = det_bboxes[:, -1] + results = results[:max_per_img] + + return results + + def loss(self, x, batch_data_samples): + assert self.dn_generator is not None, '"dn_cfg" must be set' + + batch_gt_instances = [] + batch_img_metas = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + + dn_label_query, dn_bbox_query, attn_mask, dn_meta = \ + self.dn_generator(batch_data_samples) + + outs = self(x, batch_img_metas, dn_label_query, dn_bbox_query, + attn_mask) + + loss_inputs = outs[:-1] + (batch_gt_instances, batch_img_metas, + dn_meta) + losses = self.loss_by_feat(*loss_inputs) + enc_outputs = outs[-1] + return losses, enc_outputs + + def forward_aux(self, mlvl_feats, img_metas, aux_targets, head_idx): + """Forward function. + + Args: + mlvl_feats (tuple[Tensor]): Features from the upstream + network, each is a 4D-tensor with shape + (N, C, H, W). + img_metas (list[dict]): List of image information. + + Returns: + all_cls_scores (Tensor): Outputs from the classification head, \ + shape [nb_dec, bs, num_query, cls_out_channels]. Note \ + cls_out_channels should includes background. + all_bbox_preds (Tensor): Sigmoid outputs from the regression \ + head with normalized coordinate format (cx, cy, w, h). \ + Shape [nb_dec, bs, num_query, 4]. + enc_outputs_class (Tensor): The score of each point on encode \ + feature map, has shape (N, h*w, num_class). Only when \ + as_two_stage is True it would be returned, otherwise \ + `None` would be returned. + enc_outputs_coord (Tensor): The proposal generate from the \ + encode feature map, has shape (N, h*w, 4). Only when \ + as_two_stage is True it would be returned, otherwise \ + `None` would be returned. 
+ """ + aux_coords, aux_labels, aux_targets, aux_label_weights, \ + aux_bbox_weights, aux_feats, attn_masks = aux_targets + batch_size = mlvl_feats[0].size(0) + input_img_h, input_img_w = img_metas[0]['batch_input_shape'] + img_masks = mlvl_feats[0].new_ones( + (batch_size, input_img_h, input_img_w)) + for img_id in range(batch_size): + img_h, img_w = img_metas[img_id]['img_shape'] + img_masks[img_id, :img_h, :img_w] = 0 + + mlvl_masks = [] + mlvl_positional_encodings = [] + for feat in mlvl_feats: + mlvl_masks.append( + F.interpolate(img_masks[None], + size=feat.shape[-2:]).to(torch.bool).squeeze(0)) + mlvl_positional_encodings.append( + self.positional_encoding(mlvl_masks[-1])) + + query_embeds = None + hs, inter_references = self.transformer.forward_aux( + mlvl_feats, + mlvl_masks, + query_embeds, + mlvl_positional_encodings, + aux_coords, + pos_feats=aux_feats, + reg_branches=self.reg_branches if self.with_box_refine else None, + cls_branches=self.cls_branches if self.as_two_stage else None, + return_encoder_output=True, + attn_masks=attn_masks, + head_idx=head_idx) + + hs = hs.permute(0, 2, 1, 3) + outputs_classes = [] + outputs_coords = [] + + for lvl in range(hs.shape[0]): + reference = inter_references[lvl] + reference = inverse_sigmoid(reference, eps=1e-3) + outputs_class = self.cls_branches[lvl](hs[lvl]) + tmp = self.reg_branches[lvl](hs[lvl]) + if reference.shape[-1] == 4: + tmp += reference + else: + assert reference.shape[-1] == 2 + tmp[..., :2] += reference + outputs_coord = tmp.sigmoid() + outputs_classes.append(outputs_class) + outputs_coords.append(outputs_coord) + + outputs_classes = torch.stack(outputs_classes) + outputs_coords = torch.stack(outputs_coords) + + return outputs_classes, outputs_coords, None, None + + def loss_aux(self, + x, + pos_coords=None, + head_idx=0, + batch_data_samples=None): + batch_gt_instances = [] + batch_img_metas = [] + for data_sample in batch_data_samples: + batch_img_metas.append(data_sample.metainfo) + batch_gt_instances.append(data_sample.gt_instances) + + gt_bboxes = [b.bboxes for b in batch_gt_instances] + gt_labels = [b.labels for b in batch_gt_instances] + + aux_targets = self.get_aux_targets(pos_coords, batch_img_metas, x, + head_idx) + outs = self.forward_aux(x[:-1], batch_img_metas, aux_targets, head_idx) + outs = outs + aux_targets + if gt_labels is None: + loss_inputs = outs + (gt_bboxes, batch_img_metas) + else: + loss_inputs = outs + (gt_bboxes, gt_labels, batch_img_metas) + losses = self.loss_aux_by_feat(*loss_inputs) + return losses + + def get_aux_targets(self, pos_coords, img_metas, mlvl_feats, head_idx): + coords, labels, targets = pos_coords[:3] + head_name = pos_coords[-1] + bs, c = len(coords), mlvl_feats[0].shape[1] + max_num_coords = 0 + all_feats = [] + for i in range(bs): + label = labels[i] + feats = [ + feat[i].reshape(c, -1).transpose(1, 0) for feat in mlvl_feats + ] + feats = torch.cat(feats, dim=0) + bg_class_ind = self.num_classes + pos_inds = ((label >= 0) + & (label < bg_class_ind)).nonzero().squeeze(1) + max_num_coords = max(max_num_coords, len(pos_inds)) + all_feats.append(feats) + max_num_coords = min(self.max_pos_coords, max_num_coords) + max_num_coords = max(9, max_num_coords) + + if self.use_zero_padding: + attn_masks = [] + label_weights = coords[0].new_zeros([bs, max_num_coords]) + else: + attn_masks = None + label_weights = coords[0].new_ones([bs, max_num_coords]) + bbox_weights = coords[0].new_zeros([bs, max_num_coords, 4]) + + aux_coords, aux_labels, aux_targets, aux_feats = [], [], [], [] + + 
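 + # Per-image construction of the auxiliary positive queries: the + # loop below normalizes each positive box to (cx, cy, w, h) in + # [0, 1], gathers its feature vector, and pads every image to + # `max_num_coords` entries, either with zeros plus an attention + # mask (`use_zero_padding=True`) or with randomly sampled + # negative samples. 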
for i in range(bs): + coord, label, target = coords[i], labels[i], targets[i] + feats = all_feats[i] + if 'rcnn' in head_name: + feats = pos_coords[-2][i] + num_coords_per_point = 1 + else: + num_coords_per_point = coord.shape[0] // feats.shape[0] + feats = feats.unsqueeze(1).repeat(1, num_coords_per_point, 1) + feats = feats.reshape(feats.shape[0] * num_coords_per_point, + feats.shape[-1]) + img_meta = img_metas[i] + img_h, img_w = img_meta['img_shape'] + factor = coord.new_tensor([img_w, img_h, img_w, + img_h]).unsqueeze(0) + bg_class_ind = self.num_classes + pos_inds = ((label >= 0) + & (label < bg_class_ind)).nonzero().squeeze(1) + neg_inds = (label == bg_class_ind).nonzero().squeeze(1) + if pos_inds.shape[0] > max_num_coords: + indices = torch.randperm( + pos_inds.shape[0])[:max_num_coords].cuda() + pos_inds = pos_inds[indices] + + coord = bbox_xyxy_to_cxcywh(coord[pos_inds] / factor) + label = label[pos_inds] + target = bbox_xyxy_to_cxcywh(target[pos_inds] / factor) + feat = feats[pos_inds] + + if self.use_zero_padding: + label_weights[i][:len(label)] = 1 + bbox_weights[i][:len(label)] = 1 + attn_mask = torch.zeros([ + max_num_coords, + max_num_coords, + ]).bool().to(coord.device) + else: + bbox_weights[i][:len(label)] = 1 + + if coord.shape[0] < max_num_coords: + padding_shape = max_num_coords - coord.shape[0] + if self.use_zero_padding: + padding_coord = coord.new_zeros([padding_shape, 4]) + padding_label = label.new_ones([padding_shape + ]) * self.num_classes + padding_target = target.new_zeros([padding_shape, 4]) + padding_feat = feat.new_zeros([padding_shape, c]) + attn_mask[coord.shape[0]:, 0:coord.shape[0], ] = True + attn_mask[:, coord.shape[0]:, ] = True + else: + indices = torch.randperm( + neg_inds.shape[0])[:padding_shape].cuda() + neg_inds = neg_inds[indices] + padding_coord = bbox_xyxy_to_cxcywh(coords[i][neg_inds] / + factor) + padding_label = labels[i][neg_inds] + padding_target = bbox_xyxy_to_cxcywh(targets[i][neg_inds] / + factor) + padding_feat = feats[neg_inds] + coord = torch.cat((coord, padding_coord), dim=0) + label = torch.cat((label, padding_label), dim=0) + target = torch.cat((target, padding_target), dim=0) + feat = torch.cat((feat, padding_feat), dim=0) + if self.use_zero_padding: + attn_masks.append(attn_mask.unsqueeze(0)) + aux_coords.append(coord.unsqueeze(0)) + aux_labels.append(label.unsqueeze(0)) + aux_targets.append(target.unsqueeze(0)) + aux_feats.append(feat.unsqueeze(0)) + + if self.use_zero_padding: + attn_masks = torch.cat( + attn_masks, dim=0).unsqueeze(1).repeat(1, 8, 1, 1) + attn_masks = attn_masks.reshape(bs * 8, max_num_coords, + max_num_coords) + else: + attn_masks = None + + aux_coords = torch.cat(aux_coords, dim=0) + aux_labels = torch.cat(aux_labels, dim=0) + aux_targets = torch.cat(aux_targets, dim=0) + aux_feats = torch.cat(aux_feats, dim=0) + aux_label_weights = label_weights + aux_bbox_weights = bbox_weights + return (aux_coords, aux_labels, aux_targets, aux_label_weights, + aux_bbox_weights, aux_feats, attn_masks) + + def loss_aux_by_feat(self, + all_cls_scores, + all_bbox_preds, + enc_cls_scores, + enc_bbox_preds, + aux_coords, + aux_labels, + aux_targets, + aux_label_weights, + aux_bbox_weights, + aux_feats, + attn_masks, + gt_bboxes_list, + gt_labels_list, + img_metas, + gt_bboxes_ignore=None): + num_dec_layers = len(all_cls_scores) + all_labels = [aux_labels for _ in range(num_dec_layers)] + all_label_weights = [aux_label_weights for _ in range(num_dec_layers)] + all_bbox_targets = [aux_targets for _ in range(num_dec_layers)] 
+ all_bbox_weights = [aux_bbox_weights for _ in range(num_dec_layers)] + img_metas_list = [img_metas for _ in range(num_dec_layers)] + all_gt_bboxes_ignore_list = [ + gt_bboxes_ignore for _ in range(num_dec_layers) + ] + + losses_cls, losses_bbox, losses_iou = multi_apply( + self._loss_aux_by_feat_single, all_cls_scores, all_bbox_preds, + all_labels, all_label_weights, all_bbox_targets, all_bbox_weights, + img_metas_list, all_gt_bboxes_ignore_list) + + loss_dict = dict() + # loss of proposal generated from encode feature map. + + # loss from the last decoder layer + loss_dict['loss_cls_aux'] = losses_cls[-1] + loss_dict['loss_bbox_aux'] = losses_bbox[-1] + loss_dict['loss_iou_aux'] = losses_iou[-1] + # loss from other decoder layers + num_dec_layer = 0 + for loss_cls_i, loss_bbox_i, loss_iou_i in zip(losses_cls[:-1], + losses_bbox[:-1], + losses_iou[:-1]): + loss_dict[f'd{num_dec_layer}.loss_cls_aux'] = loss_cls_i + loss_dict[f'd{num_dec_layer}.loss_bbox_aux'] = loss_bbox_i + loss_dict[f'd{num_dec_layer}.loss_iou_aux'] = loss_iou_i + num_dec_layer += 1 + return loss_dict + + def _loss_aux_by_feat_single(self, + cls_scores, + bbox_preds, + labels, + label_weights, + bbox_targets, + bbox_weights, + img_metas, + gt_bboxes_ignore_list=None): + num_imgs = cls_scores.size(0) + num_q = cls_scores.size(1) + + try: + labels = labels.reshape(num_imgs * num_q) + label_weights = label_weights.reshape(num_imgs * num_q) + bbox_targets = bbox_targets.reshape(num_imgs * num_q, 4) + bbox_weights = bbox_weights.reshape(num_imgs * num_q, 4) + except Exception: + return cls_scores.mean() * 0, cls_scores.mean( + ) * 0, cls_scores.mean() * 0 + + bg_class_ind = self.num_classes + num_total_pos = len( + ((labels >= 0) & (labels < bg_class_ind)).nonzero().squeeze(1)) + num_total_neg = num_imgs * num_q - num_total_pos + + # classification loss + cls_scores = cls_scores.reshape(-1, self.cls_out_channels) + # construct weighted avg_factor to match with the official DETR repo + cls_avg_factor = num_total_pos * 1.0 + \ + num_total_neg * self.bg_cls_weight + if self.sync_cls_avg_factor: + cls_avg_factor = reduce_mean( + cls_scores.new_tensor([cls_avg_factor])) + cls_avg_factor = max(cls_avg_factor, 1) + + bg_class_ind = self.num_classes + pos_inds = ((labels >= 0) + & (labels < bg_class_ind)).nonzero().squeeze(1) + scores = label_weights.new_zeros(labels.shape) + pos_bbox_targets = bbox_targets[pos_inds] + pos_decode_bbox_targets = bbox_cxcywh_to_xyxy(pos_bbox_targets) + pos_bbox_pred = bbox_preds.reshape(-1, 4)[pos_inds] + pos_decode_bbox_pred = bbox_cxcywh_to_xyxy(pos_bbox_pred) + scores[pos_inds] = bbox_overlaps( + pos_decode_bbox_pred.detach(), + pos_decode_bbox_targets, + is_aligned=True) + loss_cls = self.loss_cls( + cls_scores, (labels, scores), + weight=label_weights, + avg_factor=cls_avg_factor) + + # Compute the average number of gt boxes across all gpus, for + # normalization purposes + num_total_pos = loss_cls.new_tensor([num_total_pos]) + num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item() + + # construct factors used for rescale bboxes + factors = [] + for img_meta, bbox_pred in zip(img_metas, bbox_preds): + img_h, img_w = img_meta['img_shape'] + factor = bbox_pred.new_tensor([img_w, img_h, img_w, + img_h]).unsqueeze(0).repeat( + bbox_pred.size(0), 1) + factors.append(factor) + factors = torch.cat(factors, 0) + + # DETR regress the relative position of boxes (cxcywh) in the image, + # thus the learning target is normalized by the image size. 
So here + # we need to re-scale them for calculating IoU loss + bbox_preds = bbox_preds.reshape(-1, 4) + bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors + bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors + + # regression IoU loss, defaultly GIoU loss + loss_iou = self.loss_iou( + bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos) + + # regression L1 loss + loss_bbox = self.loss_bbox( + bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos) + return loss_cls, loss_bbox, loss_iou diff --git a/projects/CO-DETR/codetr/co_roi_head.py b/projects/CO-DETR/codetr/co_roi_head.py new file mode 100644 index 00000000000..9aafb53bedd --- /dev/null +++ b/projects/CO-DETR/codetr/co_roi_head.py @@ -0,0 +1,108 @@ +from typing import List, Tuple + +import torch +from torch import Tensor + +from mmdet.models.roi_heads import StandardRoIHead +from mmdet.models.task_modules.samplers import SamplingResult +from mmdet.models.utils import unpack_gt_instances +from mmdet.registry import MODELS +from mmdet.structures import DetDataSample +from mmdet.structures.bbox import bbox2roi +from mmdet.utils import InstanceList + + +@MODELS.register_module() +class CoStandardRoIHead(StandardRoIHead): + + def loss(self, x: Tuple[Tensor], rpn_results_list: InstanceList, + batch_data_samples: List[DetDataSample]) -> dict: + max_proposal = 2000 + + assert len(rpn_results_list) == len(batch_data_samples) + outputs = unpack_gt_instances(batch_data_samples) + batch_gt_instances, batch_gt_instances_ignore, _ = outputs + + # assign gts and sample proposals + num_imgs = len(batch_data_samples) + sampling_results = [] + for i in range(num_imgs): + # rename rpn_results.bboxes to rpn_results.priors + rpn_results = rpn_results_list[i] + rpn_results.priors = rpn_results.pop('bboxes') + + assign_result = self.bbox_assigner.assign( + rpn_results, batch_gt_instances[i], + batch_gt_instances_ignore[i]) + sampling_result = self.bbox_sampler.sample( + assign_result, + rpn_results, + batch_gt_instances[i], + feats=[lvl_feat[i][None] for lvl_feat in x]) + sampling_results.append(sampling_result) + + losses = dict() + # bbox head forward and loss + if self.with_bbox: + bbox_results = self.bbox_loss(x, sampling_results) + losses.update(bbox_results['loss_bbox']) + + bbox_targets = bbox_results['bbox_targets'] + for res in sampling_results: + max_proposal = min(max_proposal, res.bboxes.shape[0]) + ori_coords = bbox2roi([res.bboxes for res in sampling_results]) + ori_proposals, ori_labels, \ + ori_bbox_targets, ori_bbox_feats = [], [], [], [] + for i in range(num_imgs): + idx = (ori_coords[:, 0] == i).nonzero().squeeze(1) + idx = idx[:max_proposal] + ori_proposal = ori_coords[idx][:, 1:].unsqueeze(0) + ori_label = bbox_targets[0][idx].unsqueeze(0) + ori_bbox_target = bbox_targets[2][idx].unsqueeze(0) + ori_bbox_feat = bbox_results['bbox_feats'].mean(-1).mean(-1) + ori_bbox_feat = ori_bbox_feat[idx].unsqueeze(0) + ori_proposals.append(ori_proposal) + ori_labels.append(ori_label) + ori_bbox_targets.append(ori_bbox_target) + ori_bbox_feats.append(ori_bbox_feat) + ori_coords = torch.cat(ori_proposals, dim=0) + ori_labels = torch.cat(ori_labels, dim=0) + ori_bbox_targets = torch.cat(ori_bbox_targets, dim=0) + ori_bbox_feats = torch.cat(ori_bbox_feats, dim=0) + pos_coords = (ori_coords, ori_labels, ori_bbox_targets, + ori_bbox_feats, 'rcnn') + losses.update(pos_coords=pos_coords) + + return losses + + def bbox_loss(self, x: Tuple[Tensor], + sampling_results: List[SamplingResult]) -> dict: + """Perform forward propagation and loss 
 calculation of the bbox head on + the features of the upstream network. + + Args: + x (tuple[Tensor]): List of multi-level img features. + sampling_results (list[:obj:`SamplingResult`]): Sampling results. + + Returns: + dict[str, Tensor]: Usually returns a dictionary with keys: + + - `cls_score` (Tensor): Classification scores. + - `bbox_pred` (Tensor): Box energies / deltas. + - `bbox_feats` (Tensor): Extracted bbox RoI features. + - `loss_bbox` (dict): A dictionary of bbox loss components. + """ + rois = bbox2roi([res.priors for res in sampling_results]) + bbox_results = self._bbox_forward(x, rois) + + bbox_loss_and_target = self.bbox_head.loss_and_target( + cls_score=bbox_results['cls_score'], + bbox_pred=bbox_results['bbox_pred'], + rois=rois, + sampling_results=sampling_results, + rcnn_train_cfg=self.train_cfg) + + bbox_results.update(loss_bbox=bbox_loss_and_target['loss_bbox']) + # diff with StandardRoIHead: also expose the bbox targets so the + # RoI head can return positive coordinates to Co-DETR. + bbox_results.update(bbox_targets=bbox_loss_and_target['bbox_targets']) + return bbox_results diff --git a/projects/CO-DETR/codetr/codetr.py b/projects/CO-DETR/codetr/codetr.py new file mode 100644 index 00000000000..82826f64107 --- /dev/null +++ b/projects/CO-DETR/codetr/codetr.py @@ -0,0 +1,320 @@ +import copy +from typing import Tuple, Union + +import torch +import torch.nn as nn +from torch import Tensor + +from mmdet.models.detectors.base import BaseDetector +from mmdet.registry import MODELS +from mmdet.structures import OptSampleList, SampleList +from mmdet.utils import InstanceList, OptConfigType, OptMultiConfig + + +@MODELS.register_module() +class CoDETR(BaseDetector): + + def __init__( + self, + backbone, + neck=None, + query_head=None, # detr head + rpn_head=None, # two-stage rpn + roi_head=[None], # two-stage + bbox_head=[None], # one-stage + train_cfg=[None, None], + test_cfg=[None, None], + # Control whether to consider positive samples + # from the auxiliary head as additional positive queries. + with_pos_coord=True, + use_lsj=True, + eval_module='detr', + # Evaluate the Nth head. 
 + eval_index=0, + data_preprocessor: OptConfigType = None, + init_cfg: OptMultiConfig = None): + super(CoDETR, self).__init__( + data_preprocessor=data_preprocessor, init_cfg=init_cfg) + self.with_pos_coord = with_pos_coord + self.use_lsj = use_lsj + + assert eval_module in ['detr', 'one-stage', 'two-stage'] + self.eval_module = eval_module + + self.backbone = MODELS.build(backbone) + if neck is not None: + self.neck = MODELS.build(neck) + # Module index for evaluation + self.eval_index = eval_index + head_idx = 0 + if query_head is not None: + query_head.update(train_cfg=train_cfg[head_idx] if ( + train_cfg is not None and train_cfg[head_idx] is not None + ) else None) + query_head.update(test_cfg=test_cfg[head_idx]) + self.query_head = MODELS.build(query_head) + self.query_head.init_weights() + head_idx += 1 + + if rpn_head is not None: + rpn_train_cfg = train_cfg[head_idx].rpn if ( + train_cfg is not None + and train_cfg[head_idx] is not None) else None + rpn_head_ = rpn_head.copy() + rpn_head_.update( + train_cfg=rpn_train_cfg, test_cfg=test_cfg[head_idx].rpn) + self.rpn_head = MODELS.build(rpn_head_) + self.rpn_head.init_weights() + + self.roi_head = nn.ModuleList() + for i in range(len(roi_head)): + if roi_head[i]: + rcnn_train_cfg = train_cfg[i + head_idx].rcnn if ( + train_cfg + and train_cfg[i + head_idx] is not None) else None + roi_head[i].update(train_cfg=rcnn_train_cfg) + roi_head[i].update(test_cfg=test_cfg[i + head_idx].rcnn) + self.roi_head.append(MODELS.build(roi_head[i])) + self.roi_head[-1].init_weights() + + self.bbox_head = nn.ModuleList() + for i in range(len(bbox_head)): + if bbox_head[i]: + bbox_head[i].update( + train_cfg=train_cfg[i + head_idx + len(self.roi_head)] if ( + train_cfg and train_cfg[i + head_idx + + len(self.roi_head)] is not None + ) else None) + bbox_head[i].update(test_cfg=test_cfg[i + head_idx + + len(self.roi_head)]) + self.bbox_head.append(MODELS.build(bbox_head[i])) + self.bbox_head[-1].init_weights() + + self.head_idx = head_idx + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + @property + def with_rpn(self): + """bool: whether the detector has RPN""" + return hasattr(self, 'rpn_head') and self.rpn_head is not None + + @property + def with_query_head(self): + """bool: whether the detector has a DETR query head""" + return hasattr(self, 'query_head') and self.query_head is not None + + @property + def with_roi_head(self): + """bool: whether the detector has a RoI head""" + return hasattr(self, 'roi_head') and self.roi_head is not None and len( + self.roi_head) > 0 + + @property + def with_shared_head(self): + """bool: whether the detector has a shared head in the RoI Head""" + return hasattr(self, 'roi_head') and self.roi_head[0].with_shared_head + + @property + def with_bbox(self): + """bool: whether the detector has a bbox head""" + return ((hasattr(self, 'roi_head') and self.roi_head is not None + and len(self.roi_head) > 0) + or (hasattr(self, 'bbox_head') and self.bbox_head is not None + and len(self.bbox_head) > 0)) + + def extract_feat(self, batch_inputs: Tensor) -> Tuple[Tensor]: + """Extract features. + + Args: + batch_inputs (Tensor): Image tensor, has shape (bs, dim, H, W). + + Returns: + tuple[Tensor]: Tuple of feature maps from neck. Each feature map + has shape (bs, dim, H, W). 
+ """ + x = self.backbone(batch_inputs) + if self.with_neck: + x = self.neck(x) + return x + + def _forward(self, + batch_inputs: Tensor, + batch_data_samples: OptSampleList = None): + pass + + def loss(self, batch_inputs: Tensor, + batch_data_samples: SampleList) -> Union[dict, list]: + batch_input_shape = batch_data_samples[0].batch_input_shape + if self.use_lsj: + for data_samples in batch_data_samples: + img_metas = data_samples.metainfo + input_img_h, input_img_w = batch_input_shape + img_metas['img_shape'] = [input_img_h, input_img_w] + + x = self.extract_feat(batch_inputs) + + losses = dict() + + def upd_loss(losses, idx, weight=1): + new_losses = dict() + for k, v in losses.items(): + new_k = '{}{}'.format(k, idx) + if isinstance(v, list) or isinstance(v, tuple): + new_losses[new_k] = [i * weight for i in v] + else: + new_losses[new_k] = v * weight + return new_losses + + # DETR encoder and decoder forward + if self.with_query_head: + bbox_losses, x = self.query_head.loss(x, batch_data_samples) + losses.update(bbox_losses) + + # RPN forward and loss + if self.with_rpn: + proposal_cfg = self.train_cfg[self.head_idx].get( + 'rpn_proposal', self.test_cfg[self.head_idx].rpn) + + rpn_data_samples = copy.deepcopy(batch_data_samples) + # set cat_id of gt_labels to 0 in RPN + for data_sample in rpn_data_samples: + data_sample.gt_instances.labels = \ + torch.zeros_like(data_sample.gt_instances.labels) + + rpn_losses, proposal_list = self.rpn_head.loss_and_predict( + x, rpn_data_samples, proposal_cfg=proposal_cfg) + + # avoid get same name with roi_head loss + keys = rpn_losses.keys() + for key in list(keys): + if 'loss' in key and 'rpn' not in key: + rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key) + + losses.update(rpn_losses) + else: + assert batch_data_samples[0].get('proposals', None) is not None + # use pre-defined proposals in InstanceData for the second stage + # to extract ROI features. + proposal_list = [ + data_sample.proposals for data_sample in batch_data_samples + ] + + positive_coords = [] + for i in range(len(self.roi_head)): + roi_losses = self.roi_head[i].loss(x, proposal_list, + batch_data_samples) + if self.with_pos_coord: + positive_coords.append(roi_losses.pop('pos_coords')) + else: + if 'pos_coords' in roi_losses.keys(): + roi_losses.pop('pos_coords') + roi_losses = upd_loss(roi_losses, idx=i) + losses.update(roi_losses) + + for i in range(len(self.bbox_head)): + bbox_losses = self.bbox_head[i].loss(x, batch_data_samples) + if self.with_pos_coord: + pos_coords = bbox_losses.pop('pos_coords') + positive_coords.append(pos_coords) + else: + if 'pos_coords' in bbox_losses.keys(): + bbox_losses.pop('pos_coords') + bbox_losses = upd_loss(bbox_losses, idx=i + len(self.roi_head)) + losses.update(bbox_losses) + + if self.with_pos_coord and len(positive_coords) > 0: + for i in range(len(positive_coords)): + bbox_losses = self.query_head.loss_aux(x, positive_coords[i], + i, batch_data_samples) + bbox_losses = upd_loss(bbox_losses, idx=i) + losses.update(bbox_losses) + + return losses + + def predict(self, + batch_inputs: Tensor, + batch_data_samples: SampleList, + rescale: bool = True) -> SampleList: + """Predict results from a batch of inputs and data samples with post- + processing. + + Args: + batch_inputs (Tensor): Inputs, has shape (bs, dim, H, W). + batch_data_samples (List[:obj:`DetDataSample`]): The batch + data samples. It usually includes information such + as `gt_instance` or `gt_panoptic_seg` or `gt_sem_seg`. + rescale (bool): Whether to rescale the results. 
+ Defaults to True. + + Returns: + list[:obj:`DetDataSample`]: Detection results of the input images. + Each DetDataSample usually contain 'pred_instances'. And the + `pred_instances` usually contains following keys. + + - scores (Tensor): Classification scores, has a shape + (num_instance, ) + - labels (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes (Tensor): Has a shape (num_instances, 4), + the last dimension 4 arrange as (x1, y1, x2, y2). + """ + assert self.eval_module in ['detr', 'one-stage', 'two-stage'] + + if self.use_lsj: + for data_samples in batch_data_samples: + img_metas = data_samples.metainfo + input_img_h, input_img_w = img_metas['batch_input_shape'] + img_metas['img_shape'] = [input_img_h, input_img_w] + + img_feats = self.extract_feat(batch_inputs) + if self.with_bbox and self.eval_module == 'one-stage': + results_list = self.predict_bbox_head( + img_feats, batch_data_samples, rescale=rescale) + elif self.with_roi_head and self.eval_module == 'two-stage': + results_list = self.predict_roi_head( + img_feats, batch_data_samples, rescale=rescale) + else: + results_list = self.predict_query_head( + img_feats, batch_data_samples, rescale=rescale) + + batch_data_samples = self.add_pred_to_datasample( + batch_data_samples, results_list) + return batch_data_samples + + def predict_query_head(self, + mlvl_feats: Tuple[Tensor], + batch_data_samples: SampleList, + rescale: bool = True) -> InstanceList: + return self.query_head.predict( + mlvl_feats, batch_data_samples=batch_data_samples, rescale=rescale) + + def predict_roi_head(self, + mlvl_feats: Tuple[Tensor], + batch_data_samples: SampleList, + rescale: bool = True) -> InstanceList: + assert self.with_bbox, 'Bbox head must be implemented.' + if self.with_query_head: + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + results = self.query_head.forward(mlvl_feats, batch_img_metas) + mlvl_feats = results[-1] + rpn_results_list = self.rpn_head.predict( + mlvl_feats, batch_data_samples, rescale=False) + return self.roi_head[self.eval_index].predict( + mlvl_feats, rpn_results_list, batch_data_samples, rescale=rescale) + + def predict_bbox_head(self, + mlvl_feats: Tuple[Tensor], + batch_data_samples: SampleList, + rescale: bool = True) -> InstanceList: + assert self.with_bbox, 'Bbox head must be implemented.' 
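 + # Run the query head first and reuse its encoder output (the + # last element of its forward results) as the multi-level + # features for the one-stage auxiliary head. 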
+ if self.with_query_head: + batch_img_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + results = self.query_head.forward(mlvl_feats, batch_img_metas) + mlvl_feats = results[-1] + return self.bbox_head[self.eval_index].predict( + mlvl_feats, batch_data_samples, rescale=rescale) diff --git a/projects/CO-DETR/codetr/transformer.py b/projects/CO-DETR/codetr/transformer.py new file mode 100644 index 00000000000..009f94a8bcc --- /dev/null +++ b/projects/CO-DETR/codetr/transformer.py @@ -0,0 +1,1376 @@ +import math +import warnings + +import torch +import torch.nn as nn +from mmcv.cnn import build_norm_layer +from mmcv.cnn.bricks.transformer import (BaseTransformerLayer, + TransformerLayerSequence, + build_transformer_layer_sequence) +from mmcv.ops import MultiScaleDeformableAttention +from mmengine.model import BaseModule +from mmengine.model.weight_init import xavier_init +from torch.nn.init import normal_ + +from mmdet.models.layers.transformer import inverse_sigmoid +from mmdet.registry import MODELS + +try: + from fairscale.nn.checkpoint import checkpoint_wrapper +except Exception: + checkpoint_wrapper = None + +# In order to save the cost and effort of reproduction, +# I did not refactor it into the style of mmdet 3.x DETR. + + +class Transformer(BaseModule): + """Implements the DETR transformer. + + Following the official DETR implementation, this module copy-paste + from torch.nn.Transformer with modifications: + + * positional encodings are passed in MultiheadAttention + * extra LN at the end of encoder is removed + * decoder returns a stack of activations from all decoding layers + + See `paper: End-to-End Object Detection with Transformers + `_ for details. + + Args: + encoder (`mmcv.ConfigDict` | Dict): Config of + TransformerEncoder. Defaults to None. + decoder ((`mmcv.ConfigDict` | Dict)): Config of + TransformerDecoder. Defaults to None + init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization. + Defaults to None. + """ + + def __init__(self, encoder=None, decoder=None, init_cfg=None): + super(Transformer, self).__init__(init_cfg=init_cfg) + self.encoder = build_transformer_layer_sequence(encoder) + self.decoder = build_transformer_layer_sequence(decoder) + self.embed_dims = self.encoder.embed_dims + + def init_weights(self): + # follow the official DETR to init parameters + for m in self.modules(): + if hasattr(m, 'weight') and m.weight.dim() > 1: + xavier_init(m, distribution='uniform') + self._is_init = True + + def forward(self, x, mask, query_embed, pos_embed): + """Forward function for `Transformer`. + + Args: + x (Tensor): Input query with shape [bs, c, h, w] where + c = embed_dims. + mask (Tensor): The key_padding_mask used for encoder and decoder, + with shape [bs, h, w]. + query_embed (Tensor): The query embedding for decoder, with shape + [num_query, c]. + pos_embed (Tensor): The positional encoding for encoder and + decoder, with the same shape as `x`. + + Returns: + tuple[Tensor]: results of decoder containing the following tensor. + + - out_dec: Output from decoder. If return_intermediate_dec \ + is True output has shape [num_dec_layers, bs, + num_query, embed_dims], else has shape [1, bs, \ + num_query, embed_dims]. + - memory: Output results from encoder, with shape \ + [bs, embed_dims, h, w]. 
+ """ + bs, c, h, w = x.shape + # use `view` instead of `flatten` for dynamically exporting to ONNX + x = x.view(bs, c, -1).permute(2, 0, 1) # [bs, c, h, w] -> [h*w, bs, c] + pos_embed = pos_embed.view(bs, c, -1).permute(2, 0, 1) + query_embed = query_embed.unsqueeze(1).repeat( + 1, bs, 1) # [num_query, dim] -> [num_query, bs, dim] + mask = mask.view(bs, -1) # [bs, h, w] -> [bs, h*w] + memory = self.encoder( + query=x, + key=None, + value=None, + query_pos=pos_embed, + query_key_padding_mask=mask) + target = torch.zeros_like(query_embed) + # out_dec: [num_layers, num_query, bs, dim] + out_dec = self.decoder( + query=target, + key=memory, + value=memory, + key_pos=pos_embed, + query_pos=query_embed, + key_padding_mask=mask) + out_dec = out_dec.transpose(1, 2) + memory = memory.permute(1, 2, 0).reshape(bs, c, h, w) + return out_dec, memory + + +@MODELS.register_module(force=True) +class DeformableDetrTransformerDecoder(TransformerLayerSequence): + """Implements the decoder in DETR transformer. + + Args: + return_intermediate (bool): Whether to return intermediate outputs. + coder_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. + """ + + def __init__(self, *args, return_intermediate=False, **kwargs): + + super(DeformableDetrTransformerDecoder, self).__init__(*args, **kwargs) + self.return_intermediate = return_intermediate + + def forward(self, + query, + *args, + reference_points=None, + valid_ratios=None, + reg_branches=None, + **kwargs): + """Forward function for `TransformerDecoder`. + + Args: + query (Tensor): Input query with shape + `(num_query, bs, embed_dims)`. + reference_points (Tensor): The reference + points of offset. has shape + (bs, num_query, 4) when as_two_stage, + otherwise has shape ((bs, num_query, 2). + valid_ratios (Tensor): The radios of valid + points on the feature map, has shape + (bs, num_levels, 2) + reg_branch: (obj:`nn.ModuleList`): Used for + refining the regression results. Only would + be passed when with_box_refine is True, + otherwise would be passed a `None`. + + Returns: + Tensor: Results with shape [1, num_query, bs, embed_dims] when + return_intermediate is `False`, otherwise it has shape + [num_layers, num_query, bs, embed_dims]. 
+ """ + output = query + intermediate = [] + intermediate_reference_points = [] + for lid, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = reference_points[:, :, None] * \ + torch.cat([valid_ratios, valid_ratios], -1)[:, None] + else: + assert reference_points.shape[-1] == 2 + reference_points_input = reference_points[:, :, None] * \ + valid_ratios[:, None] + output = layer( + output, + *args, + reference_points=reference_points_input, + **kwargs) + output = output.permute(1, 0, 2) + + if reg_branches is not None: + tmp = reg_branches[lid](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid( + reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[..., :2] = tmp[ + ..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + output = output.permute(1, 0, 2) + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append(reference_points) + + if self.return_intermediate: + return torch.stack(intermediate), torch.stack( + intermediate_reference_points) + + return output, reference_points + + +@MODELS.register_module(force=True) +class DeformableDetrTransformer(Transformer): + """Implements the DeformableDETR transformer. + + Args: + as_two_stage (bool): Generate query from encoder features. + Default: False. + num_feature_levels (int): Number of feature maps from FPN: + Default: 4. + two_stage_num_proposals (int): Number of proposals when set + `as_two_stage` as True. Default: 300. + """ + + def __init__(self, + as_two_stage=False, + num_feature_levels=4, + two_stage_num_proposals=300, + **kwargs): + super(DeformableDetrTransformer, self).__init__(**kwargs) + self.as_two_stage = as_two_stage + self.num_feature_levels = num_feature_levels + self.two_stage_num_proposals = two_stage_num_proposals + self.embed_dims = self.encoder.embed_dims + self.init_layers() + + def init_layers(self): + """Initialize layers of the DeformableDetrTransformer.""" + self.level_embeds = nn.Parameter( + torch.Tensor(self.num_feature_levels, self.embed_dims)) + + if self.as_two_stage: + self.enc_output = nn.Linear(self.embed_dims, self.embed_dims) + self.enc_output_norm = nn.LayerNorm(self.embed_dims) + self.pos_trans = nn.Linear(self.embed_dims * 2, + self.embed_dims * 2) + self.pos_trans_norm = nn.LayerNorm(self.embed_dims * 2) + else: + self.reference_points = nn.Linear(self.embed_dims, 2) + + def init_weights(self): + """Initialize the transformer weights.""" + for p in self.parameters(): + if p.dim() > 1: + nn.init.xavier_uniform_(p) + for m in self.modules(): + if isinstance(m, MultiScaleDeformableAttention): + m.init_weights() + if not self.as_two_stage: + xavier_init(self.reference_points, distribution='uniform', bias=0.) + normal_(self.level_embeds) + + def gen_encoder_output_proposals(self, memory, memory_padding_mask, + spatial_shapes): + """Generate proposals from encoded memory. + + Args: + memory (Tensor) : The output of encoder, + has shape (bs, num_key, embed_dim). num_key is + equal the number of points on feature map from + all level. + memory_padding_mask (Tensor): Padding mask for memory. + has shape (bs, num_key). + spatial_shapes (Tensor): The shape of all feature maps. + has shape (num_level, 2). 
+
+        Returns:
+            tuple: A tuple of feature map and bbox prediction.
+
+                - output_memory (Tensor): The input of decoder, \
+                    has shape (bs, num_key, embed_dim). num_key is \
+                    equal to the number of points on the feature maps \
+                    from all levels.
+                - output_proposals (Tensor): The normalized proposal \
+                    after an inverse sigmoid, has shape \
+                    (bs, num_keys, 4).
+        """
+
+        N, S, C = memory.shape
+        proposals = []
+        _cur = 0
+        for lvl, (H, W) in enumerate(spatial_shapes):
+            mask_flatten_ = memory_padding_mask[:, _cur:(_cur + H * W)].view(
+                N, H, W, 1)
+            valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
+            valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
+
+            grid_y, grid_x = torch.meshgrid(
+                torch.linspace(
+                    0, H - 1, H, dtype=torch.float32, device=memory.device),
+                torch.linspace(
+                    0, W - 1, W, dtype=torch.float32, device=memory.device))
+            grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
+
+            scale = torch.cat([valid_W.unsqueeze(-1),
+                               valid_H.unsqueeze(-1)], 1).view(N, 1, 1, 2)
+            grid = (grid.unsqueeze(0).expand(N, -1, -1, -1) + 0.5) / scale
+            wh = torch.ones_like(grid) * 0.05 * (2.0**lvl)
+            proposal = torch.cat((grid, wh), -1).view(N, -1, 4)
+            proposals.append(proposal)
+            _cur += (H * W)
+        output_proposals = torch.cat(proposals, 1)
+        output_proposals_valid = ((output_proposals > 0.01) &
+                                  (output_proposals < 0.99)).all(
+                                      -1, keepdim=True)
+        output_proposals = torch.log(output_proposals / (1 - output_proposals))
+        output_proposals = output_proposals.masked_fill(
+            memory_padding_mask.unsqueeze(-1), float('inf'))
+        output_proposals = output_proposals.masked_fill(
+            ~output_proposals_valid, float('inf'))
+
+        output_memory = memory
+        output_memory = output_memory.masked_fill(
+            memory_padding_mask.unsqueeze(-1), float(0))
+        output_memory = output_memory.masked_fill(~output_proposals_valid,
+                                                  float(0))
+        output_memory = self.enc_output_norm(self.enc_output(output_memory))
+        return output_memory, output_proposals
+
+    @staticmethod
+    def get_reference_points(spatial_shapes, valid_ratios, device):
+        """Get the reference points used in decoder.
+
+        Args:
+            spatial_shapes (Tensor): The shape of all
+                feature maps, has shape (num_level, 2).
+            valid_ratios (Tensor): The ratios of valid
+                points on the feature map, has shape
+                (bs, num_levels, 2).
+            device (obj:`device`): The device where
+                reference_points should be.
+
+        Returns:
+            Tensor: reference points used in decoder, has \
+                shape (bs, num_keys, num_levels, 2).
+ """ + reference_points_list = [] + for lvl, (H, W) in enumerate(spatial_shapes): + ref_y, ref_x = torch.meshgrid( + torch.linspace( + 0.5, H - 0.5, H, dtype=torch.float32, device=device), + torch.linspace( + 0.5, W - 0.5, W, dtype=torch.float32, device=device)) + ref_y = ref_y.reshape(-1)[None] / ( + valid_ratios[:, None, lvl, 1] * H) + ref_x = ref_x.reshape(-1)[None] / ( + valid_ratios[:, None, lvl, 0] * W) + ref = torch.stack((ref_x, ref_y), -1) + reference_points_list.append(ref) + reference_points = torch.cat(reference_points_list, 1) + reference_points = reference_points[:, :, None] * valid_ratios[:, None] + return reference_points + + def get_valid_ratio(self, mask): + """Get the valid radios of feature maps of all level.""" + _, H, W = mask.shape + valid_H = torch.sum(~mask[:, :, 0], 1) + valid_W = torch.sum(~mask[:, 0, :], 1) + valid_ratio_h = valid_H.float() / H + valid_ratio_w = valid_W.float() / W + valid_ratio = torch.stack([valid_ratio_w, valid_ratio_h], -1) + return valid_ratio + + def get_proposal_pos_embed(self, + proposals, + num_pos_feats=128, + temperature=10000): + """Get the position embedding of proposal.""" + scale = 2 * math.pi + dim_t = torch.arange( + num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats) + # N, L, 4 + proposals = proposals.sigmoid() * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), + dim=4).flatten(2) + return pos + + def forward(self, + mlvl_feats, + mlvl_masks, + query_embed, + mlvl_pos_embeds, + reg_branches=None, + cls_branches=None, + **kwargs): + """Forward function for `Transformer`. + + Args: + mlvl_feats (list(Tensor)): Input queries from + different level. Each element has shape + [bs, embed_dims, h, w]. + mlvl_masks (list(Tensor)): The key_padding_mask from + different level used for encoder and decoder, + each element has shape [bs, h, w]. + query_embed (Tensor): The query embedding for decoder, + with shape [num_query, c]. + mlvl_pos_embeds (list(Tensor)): The positional encoding + of feats from different level, has the shape + [bs, embed_dims, h, w]. + reg_branches (obj:`nn.ModuleList`): Regression heads for + feature maps from each decoder layer. Only would + be passed when + `with_box_refine` is True. Default to None. + cls_branches (obj:`nn.ModuleList`): Classification heads + for feature maps from each decoder layer. Only would + be passed when `as_two_stage` + is True. Default to None. + + + Returns: + tuple[Tensor]: results of decoder containing the following tensor. + + - inter_states: Outputs from decoder. If + return_intermediate_dec is True output has shape \ + (num_dec_layers, bs, num_query, embed_dims), else has \ + shape (1, bs, num_query, embed_dims). + - init_reference_out: The initial value of reference \ + points, has shape (bs, num_queries, 4). + - inter_references_out: The internal value of reference \ + points in decoder, has shape \ + (num_dec_layers, bs,num_query, embed_dims) + - enc_outputs_class: The classification score of \ + proposals generated from \ + encoder's feature maps, has shape \ + (batch, h*w, num_classes). \ + Only would be returned when `as_two_stage` is True, \ + otherwise None. + - enc_outputs_coord_unact: The regression results \ + generated from encoder's feature maps., has shape \ + (batch, h*w, 4). Only would \ + be returned when `as_two_stage` is True, \ + otherwise None. 
+ """ + assert self.as_two_stage or query_embed is not None + + feat_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (feat, mask, pos_embed) in enumerate( + zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): + bs, c, h, w = feat.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + feat = feat.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + feat_flatten.append(feat) + mask_flatten.append(mask) + feat_flatten = torch.cat(feat_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor( + spatial_shapes, dtype=torch.long, device=feat_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros( + (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack( + [self.get_valid_ratio(m) for m in mlvl_masks], 1) + + reference_points = \ + self.get_reference_points(spatial_shapes, + valid_ratios, + device=feat.device) + + feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) + lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute( + 1, 0, 2) # (H*W, bs, embed_dims) + memory = self.encoder( + query=feat_flatten, + key=None, + value=None, + query_pos=lvl_pos_embed_flatten, + query_key_padding_mask=mask_flatten, + spatial_shapes=spatial_shapes, + reference_points=reference_points, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + **kwargs) + + memory = memory.permute(1, 0, 2) + bs, _, c = memory.shape + if self.as_two_stage: + output_memory, output_proposals = \ + self.gen_encoder_output_proposals( + memory, mask_flatten, spatial_shapes) + enc_outputs_class = cls_branches[self.decoder.num_layers]( + output_memory) + enc_outputs_coord_unact = \ + reg_branches[ + self.decoder.num_layers](output_memory) + output_proposals + + topk = self.two_stage_num_proposals + # We only use the first channel in enc_outputs_class as foreground, + # the other (num_classes - 1) channels are actually not used. + # Its targets are set to be 0s, which indicates the first + # class (foreground) because we use [0, num_classes - 1] to + # indicate class labels, background class is indicated by + # num_classes (similar convention in RPN). + # See https://github.com/open-mmlab/mmdetection/blob/master/mmdet/models/dense_heads/deformable_detr_head.py#L241 # noqa + # This follows the official implementation of Deformable DETR. 
+ topk_proposals = torch.topk( + enc_outputs_class[..., 0], topk, dim=1)[1] + topk_coords_unact = torch.gather( + enc_outputs_coord_unact, 1, + topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords_unact = topk_coords_unact.detach() + reference_points = topk_coords_unact.sigmoid() + init_reference_out = reference_points + pos_trans_out = self.pos_trans_norm( + self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))) + query_pos, query = torch.split(pos_trans_out, c, dim=2) + else: + query_pos, query = torch.split(query_embed, c, dim=1) + query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) + query = query.unsqueeze(0).expand(bs, -1, -1) + reference_points = self.reference_points(query_pos).sigmoid() + init_reference_out = reference_points + + # decoder + query = query.permute(1, 0, 2) + memory = memory.permute(1, 0, 2) + query_pos = query_pos.permute(1, 0, 2) + inter_states, inter_references = self.decoder( + query=query, + key=None, + value=memory, + query_pos=query_pos, + key_padding_mask=mask_flatten, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reg_branches=reg_branches, + **kwargs) + + inter_references_out = inter_references + if self.as_two_stage: + return inter_states, init_reference_out,\ + inter_references_out, enc_outputs_class,\ + enc_outputs_coord_unact + return inter_states, init_reference_out, \ + inter_references_out, None, None + + +@MODELS.register_module() +class CoDeformableDetrTransformerDecoder(TransformerLayerSequence): + """Implements the decoder in DETR transformer. + + Args: + return_intermediate (bool): Whether to return intermediate outputs. + coder_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. + """ + + def __init__(self, + *args, + return_intermediate=False, + look_forward_twice=False, + **kwargs): + + super(CoDeformableDetrTransformerDecoder, + self).__init__(*args, **kwargs) + self.return_intermediate = return_intermediate + self.look_forward_twice = look_forward_twice + + def forward(self, + query, + *args, + reference_points=None, + valid_ratios=None, + reg_branches=None, + **kwargs): + """Forward function for `TransformerDecoder`. + + Args: + query (Tensor): Input query with shape + `(num_query, bs, embed_dims)`. + reference_points (Tensor): The reference + points of offset. has shape + (bs, num_query, 4) when as_two_stage, + otherwise has shape ((bs, num_query, 2). + valid_ratios (Tensor): The radios of valid + points on the feature map, has shape + (bs, num_levels, 2) + reg_branch: (obj:`nn.ModuleList`): Used for + refining the regression results. Only would + be passed when with_box_refine is True, + otherwise would be passed a `None`. + + Returns: + Tensor: Results with shape [1, num_query, bs, embed_dims] when + return_intermediate is `False`, otherwise it has shape + [num_layers, num_query, bs, embed_dims]. 
+ """ + output = query + intermediate = [] + intermediate_reference_points = [] + for lid, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = reference_points[:, :, None] * \ + torch.cat([valid_ratios, valid_ratios], -1)[:, None] + else: + assert reference_points.shape[-1] == 2 + reference_points_input = reference_points[:, :, None] * \ + valid_ratios[:, None] + output = layer( + output, + *args, + reference_points=reference_points_input, + **kwargs) + output = output.permute(1, 0, 2) + + if reg_branches is not None: + tmp = reg_branches[lid](output) + if reference_points.shape[-1] == 4: + new_reference_points = tmp + inverse_sigmoid( + reference_points) + new_reference_points = new_reference_points.sigmoid() + else: + assert reference_points.shape[-1] == 2 + new_reference_points = tmp + new_reference_points[..., :2] = tmp[ + ..., :2] + inverse_sigmoid(reference_points) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + output = output.permute(1, 0, 2) + if self.return_intermediate: + intermediate.append(output) + intermediate_reference_points.append( + new_reference_points if self. + look_forward_twice else reference_points) + if self.return_intermediate: + return torch.stack(intermediate), torch.stack( + intermediate_reference_points) + + return output, reference_points + + +@MODELS.register_module() +class CoDeformableDetrTransformer(DeformableDetrTransformer): + + def __init__(self, + mixed_selection=True, + with_pos_coord=True, + with_coord_feat=True, + num_co_heads=1, + **kwargs): + self.mixed_selection = mixed_selection + self.with_pos_coord = with_pos_coord + self.with_coord_feat = with_coord_feat + self.num_co_heads = num_co_heads + super(CoDeformableDetrTransformer, self).__init__(**kwargs) + self._init_layers() + + def _init_layers(self): + """Initialize layers of the CoDeformableDetrTransformer.""" + if self.with_pos_coord: + if self.num_co_heads > 0: + # bug: this code should be 'self.head_pos_embed = + # nn.Embedding(self.num_co_heads, self.embed_dims)', + # we keep this bug for reproducing our results with ResNet-50. + # You can fix this bug when reproducing results with + # swin transformer. 
+ self.head_pos_embed = nn.Embedding(self.num_co_heads, 1, 1, + self.embed_dims) + self.aux_pos_trans = nn.ModuleList() + self.aux_pos_trans_norm = nn.ModuleList() + self.pos_feats_trans = nn.ModuleList() + self.pos_feats_norm = nn.ModuleList() + for i in range(self.num_co_heads): + self.aux_pos_trans.append( + nn.Linear(self.embed_dims * 2, self.embed_dims * 2)) + self.aux_pos_trans_norm.append( + nn.LayerNorm(self.embed_dims * 2)) + if self.with_coord_feat: + self.pos_feats_trans.append( + nn.Linear(self.embed_dims, self.embed_dims)) + self.pos_feats_norm.append( + nn.LayerNorm(self.embed_dims)) + + def get_proposal_pos_embed(self, + proposals, + num_pos_feats=128, + temperature=10000): + """Get the position embedding of proposal.""" + num_pos_feats = self.embed_dims // 2 + scale = 2 * math.pi + dim_t = torch.arange( + num_pos_feats, dtype=torch.float32, device=proposals.device) + dim_t = temperature**(2 * (dim_t // 2) / num_pos_feats) + # N, L, 4 + proposals = proposals.sigmoid() * scale + # N, L, 4, 128 + pos = proposals[:, :, :, None] / dim_t + # N, L, 4, 64, 2 + pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), + dim=4).flatten(2) + return pos + + def forward(self, + mlvl_feats, + mlvl_masks, + query_embed, + mlvl_pos_embeds, + reg_branches=None, + cls_branches=None, + return_encoder_output=False, + attn_masks=None, + **kwargs): + """Forward function for `Transformer`. + + Args: + mlvl_feats (list(Tensor)): Input queries from + different level. Each element has shape + [bs, embed_dims, h, w]. + mlvl_masks (list(Tensor)): The key_padding_mask from + different level used for encoder and decoder, + each element has shape [bs, h, w]. + query_embed (Tensor): The query embedding for decoder, + with shape [num_query, c]. + mlvl_pos_embeds (list(Tensor)): The positional encoding + of feats from different level, has the shape + [bs, embed_dims, h, w]. + reg_branches (obj:`nn.ModuleList`): Regression heads for + feature maps from each decoder layer. Only would + be passed when + `with_box_refine` is True. Default to None. + cls_branches (obj:`nn.ModuleList`): Classification heads + for feature maps from each decoder layer. Only would + be passed when `as_two_stage` + is True. Default to None. + + + Returns: + tuple[Tensor]: results of decoder containing the following tensor. + + - inter_states: Outputs from decoder. If + return_intermediate_dec is True output has shape \ + (num_dec_layers, bs, num_query, embed_dims), else has \ + shape (1, bs, num_query, embed_dims). + - init_reference_out: The initial value of reference \ + points, has shape (bs, num_queries, 4). + - inter_references_out: The internal value of reference \ + points in decoder, has shape \ + (num_dec_layers, bs,num_query, embed_dims) + - enc_outputs_class: The classification score of \ + proposals generated from \ + encoder's feature maps, has shape \ + (batch, h*w, num_classes). \ + Only would be returned when `as_two_stage` is True, \ + otherwise None. + - enc_outputs_coord_unact: The regression results \ + generated from encoder's feature maps., has shape \ + (batch, h*w, 4). Only would \ + be returned when `as_two_stage` is True, \ + otherwise None. 
+ """ + assert self.as_two_stage or query_embed is not None + + feat_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (feat, mask, pos_embed) in enumerate( + zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): + bs, c, h, w = feat.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + feat = feat.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + feat_flatten.append(feat) + mask_flatten.append(mask) + feat_flatten = torch.cat(feat_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor( + spatial_shapes, dtype=torch.long, device=feat_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros( + (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack( + [self.get_valid_ratio(m) for m in mlvl_masks], 1) + + reference_points = \ + self.get_reference_points(spatial_shapes, + valid_ratios, + device=feat.device) + + feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) + lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute( + 1, 0, 2) # (H*W, bs, embed_dims) + memory = self.encoder( + query=feat_flatten, + key=None, + value=None, + query_pos=lvl_pos_embed_flatten, + query_key_padding_mask=mask_flatten, + spatial_shapes=spatial_shapes, + reference_points=reference_points, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + **kwargs) + + memory = memory.permute(1, 0, 2) + bs, _, c = memory.shape + if self.as_two_stage: + output_memory, output_proposals = \ + self.gen_encoder_output_proposals( + memory, mask_flatten, spatial_shapes) + enc_outputs_class = cls_branches[self.decoder.num_layers]( + output_memory) + enc_outputs_coord_unact = \ + reg_branches[ + self.decoder.num_layers](output_memory) + output_proposals + + topk = self.two_stage_num_proposals + topk = query_embed.shape[0] + topk_proposals = torch.topk( + enc_outputs_class[..., 0], topk, dim=1)[1] + topk_coords_unact = torch.gather( + enc_outputs_coord_unact, 1, + topk_proposals.unsqueeze(-1).repeat(1, 1, 4)) + topk_coords_unact = topk_coords_unact.detach() + reference_points = topk_coords_unact.sigmoid() + init_reference_out = reference_points + pos_trans_out = self.pos_trans_norm( + self.pos_trans(self.get_proposal_pos_embed(topk_coords_unact))) + + if not self.mixed_selection: + query_pos, query = torch.split(pos_trans_out, c, dim=2) + else: + # query_embed here is the content embed for deformable DETR + query = query_embed.unsqueeze(0).expand(bs, -1, -1) + query_pos, _ = torch.split(pos_trans_out, c, dim=2) + else: + query_pos, query = torch.split(query_embed, c, dim=1) + query_pos = query_pos.unsqueeze(0).expand(bs, -1, -1) + query = query.unsqueeze(0).expand(bs, -1, -1) + reference_points = self.reference_points(query_pos).sigmoid() + init_reference_out = reference_points + + # decoder + query = query.permute(1, 0, 2) + memory = memory.permute(1, 0, 2) + query_pos = query_pos.permute(1, 0, 2) + inter_states, inter_references = self.decoder( + query=query, + key=None, + value=memory, + query_pos=query_pos, + key_padding_mask=mask_flatten, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reg_branches=reg_branches, + attn_masks=attn_masks, + **kwargs) + + 
inter_references_out = inter_references + if self.as_two_stage: + if return_encoder_output: + return inter_states, init_reference_out,\ + inter_references_out, enc_outputs_class,\ + enc_outputs_coord_unact, memory + return inter_states, init_reference_out,\ + inter_references_out, enc_outputs_class,\ + enc_outputs_coord_unact + if return_encoder_output: + return inter_states, init_reference_out, \ + inter_references_out, None, None, memory + return inter_states, init_reference_out, \ + inter_references_out, None, None + + def forward_aux(self, + mlvl_feats, + mlvl_masks, + query_embed, + mlvl_pos_embeds, + pos_anchors, + pos_feats=None, + reg_branches=None, + cls_branches=None, + return_encoder_output=False, + attn_masks=None, + head_idx=0, + **kwargs): + feat_flatten = [] + mask_flatten = [] + spatial_shapes = [] + for lvl, (feat, mask, pos_embed) in enumerate( + zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): + bs, c, h, w = feat.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + feat = feat.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + feat_flatten.append(feat) + mask_flatten.append(mask) + feat_flatten = torch.cat(feat_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + spatial_shapes = torch.as_tensor( + spatial_shapes, dtype=torch.long, device=feat_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros( + (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack( + [self.get_valid_ratio(m) for m in mlvl_masks], 1) + + feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) + + memory = feat_flatten + memory = memory.permute(1, 0, 2) + bs, _, c = memory.shape + + topk_coords_unact = inverse_sigmoid(pos_anchors) + reference_points = pos_anchors + init_reference_out = reference_points + if self.num_co_heads > 0: + pos_trans_out = self.aux_pos_trans_norm[head_idx]( + self.aux_pos_trans[head_idx]( + self.get_proposal_pos_embed(topk_coords_unact))) + query_pos, query = torch.split(pos_trans_out, c, dim=2) + if self.with_coord_feat: + query = query + self.pos_feats_norm[head_idx]( + self.pos_feats_trans[head_idx](pos_feats)) + query_pos = query_pos + self.head_pos_embed.weight[head_idx] + + # decoder + query = query.permute(1, 0, 2) + memory = memory.permute(1, 0, 2) + query_pos = query_pos.permute(1, 0, 2) + inter_states, inter_references = self.decoder( + query=query, + key=None, + value=memory, + query_pos=query_pos, + key_padding_mask=mask_flatten, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reg_branches=reg_branches, + attn_masks=attn_masks, + **kwargs) + + inter_references_out = inter_references + return inter_states, init_reference_out, \ + inter_references_out + + +def build_MLP(input_dim, hidden_dim, output_dim, num_layers): + assert num_layers > 1, \ + f'num_layers should be greater than 1 but got {num_layers}' + h = [hidden_dim] * (num_layers - 1) + layers = list() + for n, k in zip([input_dim] + h[:-1], h): + layers.extend((nn.Linear(n, k), nn.ReLU())) + # Note that the relu func of MLP in original DETR repo is set + # 'inplace=False', however the ReLU cfg of FFN in mmdet is set + # 'inplace=True' by default. 
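+    # Example with illustrative sizes: build_MLP(512, 256, 4, 3) returns
+    # Sequential(Linear(512, 256), ReLU(), Linear(256, 256), ReLU(),
+    # Linear(256, 4)).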
+ layers.append(nn.Linear(hidden_dim, output_dim)) + return nn.Sequential(*layers) + + +@MODELS.register_module() +class DinoTransformerDecoder(DeformableDetrTransformerDecoder): + + def __init__(self, *args, **kwargs): + super(DinoTransformerDecoder, self).__init__(*args, **kwargs) + self._init_layers() + + def _init_layers(self): + self.ref_point_head = build_MLP(self.embed_dims * 2, self.embed_dims, + self.embed_dims, 2) + self.norm = nn.LayerNorm(self.embed_dims) + + @staticmethod + def gen_sineembed_for_position(pos_tensor, pos_feat): + # n_query, bs, _ = pos_tensor.size() + # sineembed_tensor = torch.zeros(n_query, bs, 256) + scale = 2 * math.pi + dim_t = torch.arange( + pos_feat, dtype=torch.float32, device=pos_tensor.device) + dim_t = 10000**(2 * (dim_t // 2) / pos_feat) + x_embed = pos_tensor[:, :, 0] * scale + y_embed = pos_tensor[:, :, 1] * scale + pos_x = x_embed[:, :, None] / dim_t + pos_y = y_embed[:, :, None] / dim_t + pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), + dim=3).flatten(2) + pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), + dim=3).flatten(2) + if pos_tensor.size(-1) == 2: + pos = torch.cat((pos_y, pos_x), dim=2) + elif pos_tensor.size(-1) == 4: + w_embed = pos_tensor[:, :, 2] * scale + pos_w = w_embed[:, :, None] / dim_t + pos_w = torch.stack( + (pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), + dim=3).flatten(2) + + h_embed = pos_tensor[:, :, 3] * scale + pos_h = h_embed[:, :, None] / dim_t + pos_h = torch.stack( + (pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), + dim=3).flatten(2) + + pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2) + else: + raise ValueError('Unknown pos_tensor shape(-1):{}'.format( + pos_tensor.size(-1))) + return pos + + def forward(self, + query, + *args, + reference_points=None, + valid_ratios=None, + reg_branches=None, + **kwargs): + output = query + intermediate = [] + intermediate_reference_points = [reference_points] + for lid, layer in enumerate(self.layers): + if reference_points.shape[-1] == 4: + reference_points_input = \ + reference_points[:, :, None] * torch.cat( + [valid_ratios, valid_ratios], -1)[:, None] + else: + assert reference_points.shape[-1] == 2 + reference_points_input = \ + reference_points[:, :, None] * valid_ratios[:, None] + + query_sine_embed = self.gen_sineembed_for_position( + reference_points_input[:, :, 0, :], self.embed_dims // 2) + query_pos = self.ref_point_head(query_sine_embed) + + query_pos = query_pos.permute(1, 0, 2) + output = layer( + output, + *args, + query_pos=query_pos, + reference_points=reference_points_input, + **kwargs) + output = output.permute(1, 0, 2) + + if reg_branches is not None: + tmp = reg_branches[lid](output) + assert reference_points.shape[-1] == 4 + new_reference_points = tmp + inverse_sigmoid( + reference_points, eps=1e-3) + new_reference_points = new_reference_points.sigmoid() + reference_points = new_reference_points.detach() + + output = output.permute(1, 0, 2) + if self.return_intermediate: + intermediate.append(self.norm(output)) + intermediate_reference_points.append(new_reference_points) + # NOTE this is for the "Look Forward Twice" module, + # in the DeformDETR, reference_points was appended. 
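+                # In other words, new_reference_points is kept un-detached
+                # here, so the box loss of the next layer can also
+                # back-propagate into this layer's refinement
+                # (DINO's "look forward twice" scheme).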
+ + if self.return_intermediate: + return torch.stack(intermediate), torch.stack( + intermediate_reference_points) + + return output, reference_points + + +@MODELS.register_module() +class CoDinoTransformer(CoDeformableDetrTransformer): + + def __init__(self, *args, **kwargs): + super(CoDinoTransformer, self).__init__(*args, **kwargs) + + def init_layers(self): + """Initialize layers of the DinoTransformer.""" + self.level_embeds = nn.Parameter( + torch.Tensor(self.num_feature_levels, self.embed_dims)) + self.enc_output = nn.Linear(self.embed_dims, self.embed_dims) + self.enc_output_norm = nn.LayerNorm(self.embed_dims) + self.query_embed = nn.Embedding(self.two_stage_num_proposals, + self.embed_dims) + + def _init_layers(self): + if self.with_pos_coord: + if self.num_co_heads > 0: + self.aux_pos_trans = nn.ModuleList() + self.aux_pos_trans_norm = nn.ModuleList() + self.pos_feats_trans = nn.ModuleList() + self.pos_feats_norm = nn.ModuleList() + for i in range(self.num_co_heads): + self.aux_pos_trans.append( + nn.Linear(self.embed_dims * 2, self.embed_dims)) + self.aux_pos_trans_norm.append( + nn.LayerNorm(self.embed_dims)) + if self.with_coord_feat: + self.pos_feats_trans.append( + nn.Linear(self.embed_dims, self.embed_dims)) + self.pos_feats_norm.append( + nn.LayerNorm(self.embed_dims)) + + def init_weights(self): + super().init_weights() + nn.init.normal_(self.query_embed.weight.data) + + def forward(self, + mlvl_feats, + mlvl_masks, + query_embed, + mlvl_pos_embeds, + dn_label_query, + dn_bbox_query, + attn_mask, + reg_branches=None, + cls_branches=None, + **kwargs): + assert self.as_two_stage and query_embed is None, \ + 'as_two_stage must be True for DINO' + + feat_flatten = [] + mask_flatten = [] + lvl_pos_embed_flatten = [] + spatial_shapes = [] + for lvl, (feat, mask, pos_embed) in enumerate( + zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): + bs, c, h, w = feat.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + feat = feat.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + pos_embed = pos_embed.flatten(2).transpose(1, 2) + lvl_pos_embed = pos_embed + self.level_embeds[lvl].view(1, 1, -1) + lvl_pos_embed_flatten.append(lvl_pos_embed) + feat_flatten.append(feat) + mask_flatten.append(mask) + feat_flatten = torch.cat(feat_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) + spatial_shapes = torch.as_tensor( + spatial_shapes, dtype=torch.long, device=feat_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros( + (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack( + [self.get_valid_ratio(m) for m in mlvl_masks], 1) + + reference_points = self.get_reference_points( + spatial_shapes, valid_ratios, device=feat.device) + + feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) + lvl_pos_embed_flatten = lvl_pos_embed_flatten.permute( + 1, 0, 2) # (H*W, bs, embed_dims) + memory = self.encoder( + query=feat_flatten, + key=None, + value=None, + query_pos=lvl_pos_embed_flatten, + query_key_padding_mask=mask_flatten, + spatial_shapes=spatial_shapes, + reference_points=reference_points, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + **kwargs) + memory = memory.permute(1, 0, 2) + bs, _, c = memory.shape + + output_memory, output_proposals = self.gen_encoder_output_proposals( + memory, mask_flatten, spatial_shapes) + enc_outputs_class = cls_branches[self.decoder.num_layers]( + output_memory) + enc_outputs_coord_unact = 
reg_branches[self.decoder.num_layers]( + output_memory) + output_proposals + cls_out_features = cls_branches[self.decoder.num_layers].out_features + topk = self.two_stage_num_proposals + # NOTE In DeformDETR, enc_outputs_class[..., 0] is used for topk + topk_indices = torch.topk(enc_outputs_class.max(-1)[0], topk, dim=1)[1] + + topk_score = torch.gather( + enc_outputs_class, 1, + topk_indices.unsqueeze(-1).repeat(1, 1, cls_out_features)) + topk_coords_unact = torch.gather( + enc_outputs_coord_unact, 1, + topk_indices.unsqueeze(-1).repeat(1, 1, 4)) + topk_anchor = topk_coords_unact.sigmoid() + topk_coords_unact = topk_coords_unact.detach() + + query = self.query_embed.weight[:, None, :].repeat(1, bs, + 1).transpose(0, 1) + # NOTE the query_embed here is not spatial query as in DETR. + # It is actually content query, which is named tgt in other + # DETR-like models + if dn_label_query is not None: + query = torch.cat([dn_label_query, query], dim=1) + if dn_bbox_query is not None: + reference_points = torch.cat([dn_bbox_query, topk_coords_unact], + dim=1) + else: + reference_points = topk_coords_unact + reference_points = reference_points.sigmoid() + # decoder + query = query.permute(1, 0, 2) + memory = memory.permute(1, 0, 2) + inter_states, inter_references = self.decoder( + query=query, + key=None, + value=memory, + attn_masks=attn_mask, + key_padding_mask=mask_flatten, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + level_start_index=level_start_index, + valid_ratios=valid_ratios, + reg_branches=reg_branches, + **kwargs) + + inter_references_out = inter_references + + return inter_states, inter_references_out, \ + topk_score, topk_anchor, memory + + def forward_aux(self, + mlvl_feats, + mlvl_masks, + query_embed, + mlvl_pos_embeds, + pos_anchors, + pos_feats=None, + reg_branches=None, + cls_branches=None, + return_encoder_output=False, + attn_masks=None, + head_idx=0, + **kwargs): + feat_flatten = [] + mask_flatten = [] + spatial_shapes = [] + for lvl, (feat, mask, pos_embed) in enumerate( + zip(mlvl_feats, mlvl_masks, mlvl_pos_embeds)): + bs, c, h, w = feat.shape + spatial_shape = (h, w) + spatial_shapes.append(spatial_shape) + feat = feat.flatten(2).transpose(1, 2) + mask = mask.flatten(1) + feat_flatten.append(feat) + mask_flatten.append(mask) + feat_flatten = torch.cat(feat_flatten, 1) + mask_flatten = torch.cat(mask_flatten, 1) + spatial_shapes = torch.as_tensor( + spatial_shapes, dtype=torch.long, device=feat_flatten.device) + level_start_index = torch.cat((spatial_shapes.new_zeros( + (1, )), spatial_shapes.prod(1).cumsum(0)[:-1])) + valid_ratios = torch.stack( + [self.get_valid_ratio(m) for m in mlvl_masks], 1) + + feat_flatten = feat_flatten.permute(1, 0, 2) # (H*W, bs, embed_dims) + + memory = feat_flatten + memory = memory.permute(1, 0, 2) + bs, _, c = memory.shape + + topk_coords_unact = inverse_sigmoid(pos_anchors) + reference_points = pos_anchors + if self.num_co_heads > 0: + pos_trans_out = self.aux_pos_trans_norm[head_idx]( + self.aux_pos_trans[head_idx]( + self.get_proposal_pos_embed(topk_coords_unact))) + query = pos_trans_out + if self.with_coord_feat: + query = query + self.pos_feats_norm[head_idx]( + self.pos_feats_trans[head_idx](pos_feats)) + + # decoder + query = query.permute(1, 0, 2) + memory = memory.permute(1, 0, 2) + inter_states, inter_references = self.decoder( + query=query, + key=None, + value=memory, + attn_masks=None, + key_padding_mask=mask_flatten, + reference_points=reference_points, + spatial_shapes=spatial_shapes, + 
level_start_index=level_start_index, + valid_ratios=valid_ratios, + reg_branches=reg_branches, + **kwargs) + + inter_references_out = inter_references + + return inter_states, inter_references_out + + +@MODELS.register_module() +class DetrTransformerEncoder(TransformerLayerSequence): + """TransformerEncoder of DETR. + + Args: + post_norm_cfg (dict): Config of last normalization layer. Default: + `LN`. Only used when `self.pre_norm` is `True` + """ + + def __init__(self, + *args, + post_norm_cfg=dict(type='LN'), + with_cp=-1, + **kwargs): + super(DetrTransformerEncoder, self).__init__(*args, **kwargs) + if post_norm_cfg is not None: + self.post_norm = build_norm_layer( + post_norm_cfg, self.embed_dims)[1] if self.pre_norm else None + else: + assert not self.pre_norm, f'Use prenorm in ' \ + f'{self.__class__.__name__},' \ + f'Please specify post_norm_cfg' + self.post_norm = None + self.with_cp = with_cp + if self.with_cp > 0: + if checkpoint_wrapper is None: + warnings.warn('If you want to reduce GPU memory usage, \ + please install fairscale by executing the \ + following command: pip install fairscale.') + return + for i in range(self.with_cp): + self.layers[i] = checkpoint_wrapper(self.layers[i]) + + +@MODELS.register_module() +class DetrTransformerDecoderLayer(BaseTransformerLayer): + """Implements decoder layer in DETR transformer. + + Args: + attn_cfgs (list[`mmcv.ConfigDict`] | list[dict] | dict )): + Configs for self_attention or cross_attention, the order + should be consistent with it in `operation_order`. If it is + a dict, it would be expand to the number of attention in + `operation_order`. + feedforward_channels (int): The hidden dimension for FFNs. + ffn_dropout (float): Probability of an element to be zeroed + in ffn. Default 0.0. + operation_order (tuple[str]): The execution order of operation + in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm'). + Default:None + act_cfg (dict): The activation config for FFNs. Default: `LN` + norm_cfg (dict): Config dict for normalization layer. + Default: `LN`. + ffn_num_fcs (int): The number of fully-connected layers in FFNs. + Default:2. + """ + + def __init__(self, + attn_cfgs, + feedforward_channels, + ffn_dropout=0.0, + operation_order=None, + act_cfg=dict(type='ReLU', inplace=True), + norm_cfg=dict(type='LN'), + ffn_num_fcs=2, + **kwargs): + super(DetrTransformerDecoderLayer, self).__init__( + attn_cfgs=attn_cfgs, + feedforward_channels=feedforward_channels, + ffn_dropout=ffn_dropout, + operation_order=operation_order, + act_cfg=act_cfg, + norm_cfg=norm_cfg, + ffn_num_fcs=ffn_num_fcs, + **kwargs) + assert len(operation_order) == 6 + assert set(operation_order) == set( + ['self_attn', 'norm', 'cross_attn', 'ffn']) diff --git a/projects/CO-DETR/configs/codino/co_dino_5scale_r50_8xb2_1x_coco.py b/projects/CO-DETR/configs/codino/co_dino_5scale_r50_8xb2_1x_coco.py new file mode 100644 index 00000000000..1a413043766 --- /dev/null +++ b/projects/CO-DETR/configs/codino/co_dino_5scale_r50_8xb2_1x_coco.py @@ -0,0 +1,68 @@ +_base_ = './co_dino_5scale_r50_lsj_8xb2_1x_coco.py' + +model = dict( + use_lsj=False, data_preprocessor=dict(pad_mask=False, batch_augments=None)) + +# train_pipeline, NOTE the img_scale and the Pad's size_divisor is different +# from the default setting in mmdet. 
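+# The RandomChoice below follows the standard DETR-style augmentation,
+# picking either a plain multi-scale resize or a
+# resize -> random crop -> resize sequence.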
+train_pipeline = [
+    dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(type='RandomFlip', prob=0.5),
+    dict(
+        type='RandomChoice',
+        transforms=[
+            [
+                dict(
+                    type='RandomChoiceResize',
+                    scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                            (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                            (736, 1333), (768, 1333), (800, 1333)],
+                    keep_ratio=True)
+            ],
+            [
+                dict(
+                    type='RandomChoiceResize',
+                    # The aspect ratio of all images in the train dataset
+                    # is < 7; follow the original implementation
+                    scales=[(400, 4200), (500, 4200), (600, 4200)],
+                    keep_ratio=True),
+                dict(
+                    type='RandomCrop',
+                    crop_type='absolute_range',
+                    crop_size=(384, 600),
+                    allow_negative_crop=True),
+                dict(
+                    type='RandomChoiceResize',
+                    scales=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                            (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                            (736, 1333), (768, 1333), (800, 1333)],
+                    keep_ratio=True)
+            ]
+        ]),
+    dict(type='PackDetInputs')
+]
+
+train_dataloader = dict(
+    dataset=dict(
+        _delete_=True,
+        type=_base_.dataset_type,
+        data_root=_base_.data_root,
+        ann_file='annotations/instances_train2017.json',
+        data_prefix=dict(img='train2017/'),
+        filter_cfg=dict(filter_empty_gt=False, min_size=32),
+        pipeline=train_pipeline,
+        backend_args=_base_.backend_args))
+
+test_pipeline = [
+    dict(type='LoadImageFromFile', backend_args=_base_.backend_args),
+    dict(type='Resize', scale=(1333, 800), keep_ratio=True),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(
+        type='PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor'))
+]
+
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+test_dataloader = val_dataloader
diff --git a/projects/CO-DETR/configs/codino/co_dino_5scale_r50_lsj_8xb2_1x_coco.py b/projects/CO-DETR/configs/codino/co_dino_5scale_r50_lsj_8xb2_1x_coco.py
new file mode 100644
index 00000000000..876b90f89c8
--- /dev/null
+++ b/projects/CO-DETR/configs/codino/co_dino_5scale_r50_lsj_8xb2_1x_coco.py
@@ -0,0 +1,359 @@
+_base_ = 'mmdet::common/ssj_scp_270k_coco-instance.py'
+
+custom_imports = dict(
+    imports=['projects.CO-DETR.codetr'], allow_failed_imports=False)
+
+# model settings
+num_dec_layer = 6
+loss_lambda = 2.0
+num_classes = 80
+
+image_size = (1024, 1024)
+batch_augments = [
+    dict(type='BatchFixedSizePad', size=image_size, pad_mask=True)
+]
+model = dict(
+    type='CoDETR',
+    # If using the LSJ augmentation,
+    # it is recommended to set this to True.
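+    # When True, CoDETR.predict overwrites each sample's img_shape with
+    # the padded batch_input_shape, so the heads treat the fixed LSJ
+    # canvas as the image shape at test time.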
+ use_lsj=True, + # detr: 52.1 + # one-stage: 49.4 + # two-stage: 47.9 + eval_module='detr', # in ['detr', 'one-stage', 'two-stage'] + data_preprocessor=dict( + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + pad_mask=True, + batch_augments=batch_augments), + backbone=dict( + type='ResNet', + depth=50, + num_stages=4, + out_indices=(0, 1, 2, 3), + frozen_stages=1, + norm_cfg=dict(type='BN', requires_grad=False), + norm_eval=True, + style='pytorch', + init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')), + neck=dict( + type='ChannelMapper', + in_channels=[256, 512, 1024, 2048], + kernel_size=1, + out_channels=256, + act_cfg=None, + norm_cfg=dict(type='GN', num_groups=32), + num_outs=5), + query_head=dict( + type='CoDINOHead', + num_query=900, + num_classes=num_classes, + in_channels=2048, + as_two_stage=True, + dn_cfg=dict( + label_noise_scale=0.5, + box_noise_scale=1.0, + group_cfg=dict(dynamic=True, num_groups=None, num_dn_queries=100)), + transformer=dict( + type='CoDinoTransformer', + with_coord_feat=False, + num_co_heads=2, # ATSS Aux Head + Faster RCNN Aux Head + num_feature_levels=5, + encoder=dict( + type='DetrTransformerEncoder', + num_layers=6, + # number of layers that use checkpoint. + # The maximum value for the setting is num_layers. + # FairScale must be installed for it to work. + with_cp=4, + transformerlayers=dict( + type='BaseTransformerLayer', + attn_cfgs=dict( + type='MultiScaleDeformableAttention', + embed_dims=256, + num_levels=5, + dropout=0.0), + feedforward_channels=2048, + ffn_dropout=0.0, + operation_order=('self_attn', 'norm', 'ffn', 'norm'))), + decoder=dict( + type='DinoTransformerDecoder', + num_layers=6, + return_intermediate=True, + transformerlayers=dict( + type='DetrTransformerDecoderLayer', + attn_cfgs=[ + dict( + type='MultiheadAttention', + embed_dims=256, + num_heads=8, + dropout=0.0), + dict( + type='MultiScaleDeformableAttention', + embed_dims=256, + num_levels=5, + dropout=0.0), + ], + feedforward_channels=2048, + ffn_dropout=0.0, + operation_order=('self_attn', 'norm', 'cross_attn', 'norm', + 'ffn', 'norm')))), + positional_encoding=dict( + type='SinePositionalEncoding', + num_feats=128, + temperature=20, + normalize=True), + loss_cls=dict( # Different from the DINO + type='QualityFocalLoss', + use_sigmoid=True, + beta=2.0, + loss_weight=1.0), + loss_bbox=dict(type='L1Loss', loss_weight=5.0), + loss_iou=dict(type='GIoULoss', loss_weight=2.0)), + rpn_head=dict( + type='RPNHead', + in_channels=256, + feat_channels=256, + anchor_generator=dict( + type='AnchorGenerator', + octave_base_scale=4, + scales_per_octave=3, + ratios=[0.5, 1.0, 2.0], + strides=[4, 8, 16, 32, 64, 128]), + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[.0, .0, .0, .0], + target_stds=[1.0, 1.0, 1.0, 1.0]), + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + loss_weight=1.0 * num_dec_layer * loss_lambda), + loss_bbox=dict( + type='L1Loss', loss_weight=1.0 * num_dec_layer * loss_lambda)), + roi_head=[ + dict( + type='CoStandardRoIHead', + bbox_roi_extractor=dict( + type='SingleRoIExtractor', + roi_layer=dict( + type='RoIAlign', output_size=7, sampling_ratio=0), + out_channels=256, + featmap_strides=[4, 8, 16, 32, 64], + finest_scale=56), + bbox_head=dict( + type='Shared2FCBBoxHead', + in_channels=256, + fc_out_channels=1024, + roi_feat_size=7, + num_classes=num_classes, + bbox_coder=dict( + type='DeltaXYWHBBoxCoder', + target_means=[0., 0., 0., 0.], + target_stds=[0.1, 
0.1, 0.2, 0.2]),
+                reg_class_agnostic=False,
+                reg_decoded_bbox=True,
+                loss_cls=dict(
+                    type='CrossEntropyLoss',
+                    use_sigmoid=False,
+                    loss_weight=1.0 * num_dec_layer * loss_lambda),
+                loss_bbox=dict(
+                    type='GIoULoss',
+                    loss_weight=10.0 * num_dec_layer * loss_lambda)))
+    ],
+    bbox_head=[
+        dict(
+            type='CoATSSHead',
+            num_classes=num_classes,
+            in_channels=256,
+            stacked_convs=1,
+            feat_channels=256,
+            anchor_generator=dict(
+                type='AnchorGenerator',
+                ratios=[1.0],
+                octave_base_scale=8,
+                scales_per_octave=1,
+                strides=[4, 8, 16, 32, 64, 128]),
+            bbox_coder=dict(
+                type='DeltaXYWHBBoxCoder',
+                target_means=[.0, .0, .0, .0],
+                target_stds=[0.1, 0.1, 0.2, 0.2]),
+            loss_cls=dict(
+                type='FocalLoss',
+                use_sigmoid=True,
+                gamma=2.0,
+                alpha=0.25,
+                loss_weight=1.0 * num_dec_layer * loss_lambda),
+            loss_bbox=dict(
+                type='GIoULoss',
+                loss_weight=2.0 * num_dec_layer * loss_lambda),
+            loss_centerness=dict(
+                type='CrossEntropyLoss',
+                use_sigmoid=True,
+                loss_weight=1.0 * num_dec_layer * loss_lambda)),
+    ],
+    # model training and testing settings
+    train_cfg=[
+        dict(
+            assigner=dict(
+                type='HungarianAssigner',
+                match_costs=[
+                    dict(type='FocalLossCost', weight=2.0),
+                    dict(type='BBoxL1Cost', weight=5.0, box_format='xywh'),
+                    dict(type='IoUCost', iou_mode='giou', weight=2.0)
+                ])),
+        dict(
+            rpn=dict(
+                assigner=dict(
+                    type='MaxIoUAssigner',
+                    pos_iou_thr=0.7,
+                    neg_iou_thr=0.3,
+                    min_pos_iou=0.3,
+                    match_low_quality=True,
+                    ignore_iof_thr=-1),
+                sampler=dict(
+                    type='RandomSampler',
+                    num=256,
+                    pos_fraction=0.5,
+                    neg_pos_ub=-1,
+                    add_gt_as_proposals=False),
+                allowed_border=-1,
+                pos_weight=-1,
+                debug=False),
+            rpn_proposal=dict(
+                nms_pre=4000,
+                max_per_img=1000,
+                nms=dict(type='nms', iou_threshold=0.7),
+                min_bbox_size=0),
+            rcnn=dict(
+                assigner=dict(
+                    type='MaxIoUAssigner',
+                    pos_iou_thr=0.5,
+                    neg_iou_thr=0.5,
+                    min_pos_iou=0.5,
+                    match_low_quality=False,
+                    ignore_iof_thr=-1),
+                sampler=dict(
+                    type='RandomSampler',
+                    num=512,
+                    pos_fraction=0.25,
+                    neg_pos_ub=-1,
+                    add_gt_as_proposals=True),
+                pos_weight=-1,
+                debug=False)),
+        dict(
+            assigner=dict(type='ATSSAssigner', topk=9),
+            allowed_border=-1,
+            pos_weight=-1,
+            debug=False)
+    ],
+    test_cfg=[
+        # Different from DINO, we use NMS here.
+        dict(
+            max_per_img=300,
+            # NMS can improve the mAP by 0.2.
+ nms=dict(type='soft_nms', iou_threshold=0.8)), + dict( + rpn=dict( + nms_pre=1000, + max_per_img=1000, + nms=dict(type='nms', iou_threshold=0.7), + min_bbox_size=0), + rcnn=dict( + score_thr=0.0, + nms=dict(type='nms', iou_threshold=0.5), + max_per_img=100)), + dict( + # atss bbox head: + nms_pre=1000, + min_bbox_size=0, + score_thr=0.0, + nms=dict(type='nms', iou_threshold=0.6), + max_per_img=100), + # soft-nms is also supported for rcnn testing + # e.g., nms=dict(type='soft_nms', iou_threshold=0.5, min_score=0.05) + ]) + +# LSJ + CopyPaste +load_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict( + type='RandomResize', + scale=image_size, + ratio_range=(0.1, 2.0), + keep_ratio=True), + dict( + type='RandomCrop', + crop_type='absolute_range', + crop_size=image_size, + recompute_bbox=True, + allow_negative_crop=True), + dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)), + dict(type='RandomFlip', prob=0.5), + dict(type='Pad', size=image_size, pad_val=dict(img=(114, 114, 114))), +] + +train_pipeline = [ + dict(type='CopyPaste', max_num_pasted=100), + dict(type='PackDetInputs') +] + +train_dataloader = dict( + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + pipeline=train_pipeline, + dataset=dict( + filter_cfg=dict(filter_empty_gt=False), pipeline=load_pipeline))) + +# follow ViTDet +test_pipeline = [ + dict(type='LoadImageFromFile'), + dict(type='Resize', scale=image_size, keep_ratio=True), # diff + dict(type='Pad', size=image_size, pad_val=dict(img=(114, 114, 114))), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] + +val_dataloader = dict(dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader + +optim_wrapper = dict( + _delete_=True, + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=2e-4, weight_decay=0.0001), + clip_grad=dict(max_norm=0.1, norm_type=2), + paramwise_cfg=dict(custom_keys={'backbone': dict(lr_mult=0.1)})) + +val_evaluator = dict(metric='bbox') +test_evaluator = val_evaluator + +max_epochs = 12 +train_cfg = dict( + _delete_=True, + type='EpochBasedTrainLoop', + max_epochs=max_epochs, + val_interval=1) + +param_scheduler = [ + dict( + type='MultiStepLR', + begin=0, + end=max_epochs, + by_epoch=True, + milestones=[11], + gamma=0.1) +] + +default_hooks = dict( + checkpoint=dict(by_epoch=True, interval=1, max_keep_ckpts=3)) +log_processor = dict(by_epoch=True) + +# NOTE: `auto_scale_lr` is for automatically scaling LR, +# USER SHOULD NOT CHANGE ITS VALUES. 
+# base_batch_size = (8 GPUs) x (2 samples per GPU)
+auto_scale_lr = dict(base_batch_size=16)
diff --git a/projects/CO-DETR/configs/codino/co_dino_5scale_r50_lsj_8xb2_3x_coco.py b/projects/CO-DETR/configs/codino/co_dino_5scale_r50_lsj_8xb2_3x_coco.py
new file mode 100644
index 00000000000..9a9fc34f680
--- /dev/null
+++ b/projects/CO-DETR/configs/codino/co_dino_5scale_r50_lsj_8xb2_3x_coco.py
@@ -0,0 +1,4 @@
+_base_ = ['co_dino_5scale_r50_lsj_8xb2_1x_coco.py']
+
+param_scheduler = [dict(milestones=[30])]
+train_cfg = dict(max_epochs=36)
diff --git a/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_16e_o365tococo.py b/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_16e_o365tococo.py
new file mode 100644
index 00000000000..8fdb73269ff
--- /dev/null
+++ b/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_16e_o365tococo.py
@@ -0,0 +1,115 @@
+_base_ = ['co_dino_5scale_r50_8xb2_1x_coco.py']
+
+pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth'  # noqa
+load_from = 'https://download.openmmlab.com/mmdetection/v3.0/codetr/co_dino_5scale_swin_large_22e_o365-0a33e247.pth'  # noqa
+
+# model settings
+model = dict(
+    backbone=dict(
+        _delete_=True,
+        type='SwinTransformer',
+        pretrain_img_size=384,
+        embed_dims=192,
+        depths=[2, 2, 18, 2],
+        num_heads=[6, 12, 24, 48],
+        window_size=12,
+        mlp_ratio=4,
+        qkv_bias=True,
+        qk_scale=None,
+        drop_rate=0.,
+        attn_drop_rate=0.,
+        drop_path_rate=0.3,
+        patch_norm=True,
+        out_indices=(0, 1, 2, 3),
+        # Please only add indices that would be used
+        # in FPN, otherwise some parameters will not be used
+        with_cp=True,
+        convert_weights=True,
+        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
+    neck=dict(in_channels=[192, 384, 768, 1536]),
+    query_head=dict(
+        dn_cfg=dict(box_noise_scale=0.4, group_cfg=dict(num_dn_queries=500)),
+        transformer=dict(encoder=dict(with_cp=6))))
+
+train_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(type='RandomFlip', prob=0.5),
+    dict(
+        type='RandomChoice',
+        transforms=[
+            [
+                dict(
+                    type='RandomChoiceResize',
+                    scales=[(480, 2048), (512, 2048), (544, 2048), (576, 2048),
+                            (608, 2048), (640, 2048), (672, 2048), (704, 2048),
+                            (736, 2048), (768, 2048), (800, 2048), (832, 2048),
+                            (864, 2048), (896, 2048), (928, 2048), (960, 2048),
+                            (992, 2048), (1024, 2048), (1056, 2048),
+                            (1088, 2048), (1120, 2048), (1152, 2048),
+                            (1184, 2048), (1216, 2048), (1248, 2048),
+                            (1280, 2048), (1312, 2048), (1344, 2048),
+                            (1376, 2048), (1408, 2048), (1440, 2048),
+                            (1472, 2048), (1504, 2048), (1536, 2048)],
+                    keep_ratio=True)
+            ],
+            [
+                dict(
+                    type='RandomChoiceResize',
+                    # The aspect ratio of all images in the train dataset
+                    # is < 7, following the original implementation
+                    scales=[(400, 4200), (500, 4200), (600, 4200)],
+                    keep_ratio=True),
+                dict(
+                    type='RandomCrop',
+                    crop_type='absolute_range',
+                    crop_size=(384, 600),
+                    allow_negative_crop=True),
+                dict(
+                    type='RandomChoiceResize',
+                    scales=[(480, 2048), (512, 2048), (544, 2048), (576, 2048),
+                            (608, 2048), (640, 2048), (672, 2048), (704, 2048),
+                            (736, 2048), (768, 2048), (800, 2048), (832, 2048),
+                            (864, 2048), (896, 2048), (928, 2048), (960, 2048),
+                            (992, 2048), (1024, 2048), (1056, 2048),
+                            (1088, 2048), (1120, 2048), (1152, 2048),
+                            (1184, 2048), (1216, 2048), (1248, 2048),
+                            (1280, 2048), (1312, 2048), (1344, 2048),
+                            (1376, 2048), (1408, 2048), (1440, 2048),
+                            (1472, 2048), (1504, 2048), (1536, 2048)],
+                    keep_ratio=True)
+            ]
+        ]),
+    dict(type='PackDetInputs')
+]
+
+train_dataloader = dict(
+    batch_size=1, num_workers=1, dataset=dict(pipeline=train_pipeline))
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='Resize', scale=(2048, 1280), keep_ratio=True),
+    dict(type='LoadAnnotations', with_bbox=True),
+    dict(
+        type='PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor'))
+]
+
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+optim_wrapper = dict(optimizer=dict(lr=1e-4))
+
+max_epochs = 16
+train_cfg = dict(max_epochs=max_epochs)
+
+param_scheduler = [
+    dict(
+        type='MultiStepLR',
+        begin=0,
+        end=max_epochs,
+        by_epoch=True,
+        milestones=[8],
+        gamma=0.1)
+]
diff --git a/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_1x_coco.py b/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_1x_coco.py
new file mode 100644
index 00000000000..d4a873464d4
--- /dev/null
+++ b/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_1x_coco.py
@@ -0,0 +1,31 @@
+_base_ = ['co_dino_5scale_r50_8xb2_1x_coco.py']
+
+pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth'  # noqa
+
+# model settings
+model = dict(
+    backbone=dict(
+        _delete_=True,
+        type='SwinTransformer',
+        pretrain_img_size=384,
+        embed_dims=192,
+        depths=[2, 2, 18, 2],
+        num_heads=[6, 12, 24, 48],
+        window_size=12,
+        mlp_ratio=4,
+        qkv_bias=True,
+        qk_scale=None,
+        drop_rate=0.,
+        attn_drop_rate=0.,
+        drop_path_rate=0.3,
+        patch_norm=True,
+        out_indices=(0, 1, 2, 3),
+        # Please only add indices that would be used
+        # in FPN, otherwise some parameters will not be used
+        with_cp=False,
+        convert_weights=True,
+        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
+    neck=dict(in_channels=[192, 384, 768, 1536]),
+    query_head=dict(transformer=dict(encoder=dict(with_cp=6))))
+
+train_dataloader = dict(batch_size=1, num_workers=1)
diff --git a/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_3x_coco.py b/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_3x_coco.py
new file mode 100644
index 00000000000..c2fce29b98b
--- /dev/null
+++ b/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_16xb1_3x_coco.py
@@ -0,0 +1,6 @@
+_base_ = ['co_dino_5scale_swin_l_16xb1_1x_coco.py']
+# model settings
+model = dict(backbone=dict(drop_path_rate=0.6))
+
+param_scheduler = [dict(milestones=[30])]
+train_cfg = dict(max_epochs=36)
diff --git a/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_lsj_16xb1_1x_coco.py b/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_lsj_16xb1_1x_coco.py
new file mode 100644
index 00000000000..4a9b3688b8e
--- /dev/null
+++ b/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_lsj_16xb1_1x_coco.py
@@ -0,0 +1,72 @@
+_base_ = ['co_dino_5scale_r50_lsj_8xb2_1x_coco.py']
+
+image_size = (1280, 1280)
+batch_augments = [
+    dict(type='BatchFixedSizePad', size=image_size, pad_mask=True)
+]
+pretrained = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_large_patch4_window12_384_22k.pth'  # noqa
+
+# model settings
+model = dict(
+    data_preprocessor=dict(batch_augments=batch_augments),
+    backbone=dict(
+        _delete_=True,
+        type='SwinTransformer',
+        pretrain_img_size=384,
+        embed_dims=192,
+        depths=[2, 2, 18, 2],
+        num_heads=[6, 12, 24, 48],
+        window_size=12,
+        mlp_ratio=4,
+        qkv_bias=True,
+        qk_scale=None,
+        drop_rate=0.,
+        attn_drop_rate=0.,
+        drop_path_rate=0.3,
+        patch_norm=True,
+        out_indices=(0, 1, 2, 3),
+        # Please only add indices that would be used
+        # in FPN, otherwise some parameters will not be used
+        with_cp=False,
+        convert_weights=True,
+        init_cfg=dict(type='Pretrained', checkpoint=pretrained)),
+    neck=dict(in_channels=[192, 384, 768, 1536]),
+    query_head=dict(transformer=dict(encoder=dict(with_cp=6))))
+
+load_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+    dict(
+        type='RandomResize',
+        scale=image_size,
+        ratio_range=(0.1, 2.0),
+        keep_ratio=True),
+    dict(
+        type='RandomCrop',
+        crop_type='absolute_range',
+        crop_size=image_size,
+        recompute_bbox=True,
+        allow_negative_crop=True),
+    dict(type='FilterAnnotations', min_gt_bbox_wh=(1e-2, 1e-2)),
+    dict(type='RandomFlip', prob=0.5),
+    dict(type='Pad', size=image_size, pad_val=dict(img=(114, 114, 114))),
+]
+
+train_dataloader = dict(
+    batch_size=1,
+    num_workers=1,
+    dataset=dict(dataset=dict(pipeline=load_pipeline)))
+
+test_pipeline = [
+    dict(type='LoadImageFromFile'),
+    dict(type='Resize', scale=image_size, keep_ratio=True),
+    dict(type='Pad', size=image_size, pad_val=dict(img=(114, 114, 114))),
+    dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+    dict(
+        type='PackDetInputs',
+        meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+                   'scale_factor'))
+]
+
+val_dataloader = dict(dataset=dict(pipeline=test_pipeline))
+test_dataloader = val_dataloader
diff --git a/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_lsj_16xb1_3x_coco.py b/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_lsj_16xb1_3x_coco.py
new file mode 100644
index 00000000000..0e5c00b2182
--- /dev/null
+++ b/projects/CO-DETR/configs/codino/co_dino_5scale_swin_l_lsj_16xb1_3x_coco.py
@@ -0,0 +1,6 @@
+_base_ = ['co_dino_5scale_swin_l_lsj_16xb1_1x_coco.py']
+
+model = dict(backbone=dict(drop_path_rate=0.5))
+
+param_scheduler = [dict(milestones=[30])]
+train_cfg = dict(max_epochs=36)
diff --git a/tools/model_converters/glip_to_mmdet.py b/tools/model_converters/glip_to_mmdet.py
index 55814d6371b..255addca5bd 100644
--- a/tools/model_converters/glip_to_mmdet.py
+++ b/tools/model_converters/glip_to_mmdet.py
@@ -97,8 +97,7 @@ def convert(ckpt):
 
 def main():
     parser = argparse.ArgumentParser(
-        description='Convert keys in pretrained eva '
-        'models to mmpretrain style.')
+        description='Convert keys to mmdet style.')
     parser.add_argument(
         'src', default='glip_a_tiny_o365.pth', help='src model path or url')
     # The dst path must be a full path of the new checkpoint.
diff --git a/tools/model_converters/swinv1_to_mmdet.py b/tools/model_converters/swinv1_to_mmdet.py
new file mode 100644
index 00000000000..5de98f464a5
--- /dev/null
+++ b/tools/model_converters/swinv1_to_mmdet.py
@@ -0,0 +1,86 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import argparse
+import subprocess
+from collections import OrderedDict
+
+import torch
+from mmengine.runner import CheckpointLoader
+
+
+def swin_converter(ckpt):
+
+    new_ckpt = OrderedDict()
+
+    def correct_unfold_reduction_order(x):
+        # The official Swin concatenates the four shifted patches as
+        # [x0, x1, x2, x3], while mmdet's PatchMerging uses nn.Unfold,
+        # so the columns of the reduction weight must be reordered.
+        out_channel, in_channel = x.shape
+        x = x.reshape(out_channel, 4, in_channel // 4)
+        x = x[:, [0, 2, 1, 3], :].transpose(1,
+                                            2).reshape(out_channel, in_channel)
+        return x
+
+    def correct_unfold_norm_order(x):
+        # Apply the same reordering to the patch-merging norm parameters.
+        in_channel = x.shape[0]
+        x = x.reshape(4, in_channel // 4)
+        x = x[[0, 2, 1, 3], :].transpose(0, 1).reshape(in_channel)
+        return x
+
+    for k, v in ckpt.items():
+        if k.startswith('backbone.layers'):
+            new_v = v
+            if 'attn.' in k:
+                new_k = k.replace('attn.', 'attn.w_msa.')
+            elif 'mlp.' in k:
+                if 'mlp.fc1.' in k:
+                    new_k = k.replace('mlp.fc1.', 'ffn.layers.0.0.')
+                elif 'mlp.fc2.' in k:
+                    new_k = k.replace('mlp.fc2.', 'ffn.layers.1.')
+                else:
+                    new_k = k.replace('mlp.', 'ffn.')
+            elif 'downsample' in k:
+                new_k = k
+                if 'reduction.' in k:
+                    new_v = correct_unfold_reduction_order(v)
+                elif 'norm.' in k:
+                    new_v = correct_unfold_norm_order(v)
+            else:
+                new_k = k
+            new_k = new_k.replace('layers', 'stages', 1)
+        elif k.startswith('backbone.patch_embed'):
+            new_v = v
+            if 'proj' in k:
+                new_k = k.replace('proj', 'projection')
+            else:
+                new_k = k
+        else:
+            new_v = v
+            new_k = k
+
+        new_ckpt[new_k] = new_v
+
+    return new_ckpt
+
+
+def main():
+    parser = argparse.ArgumentParser(
+        description='Convert keys to mmdet style.')
+    parser.add_argument('src', help='src model path or url')
+    # The dst path must be a full path of the new checkpoint.
+    parser.add_argument('dst', help='save path')
+    args = parser.parse_args()
+
+    checkpoint = CheckpointLoader.load_checkpoint(args.src, map_location='cpu')
+
+    if 'state_dict' in checkpoint:
+        state_dict = checkpoint['state_dict']
+    else:
+        state_dict = checkpoint
+    torch.save(swin_converter(state_dict), args.dst)
+
+    sha = subprocess.check_output(['sha256sum', args.dst]).decode()
+    final_file = args.dst.replace('.pth', '') + '-{}.pth'.format(sha[:8])
+    # Block until the rename finishes so the printed path actually exists
+    subprocess.run(['mv', args.dst, final_file], check=True)
+    print(f'Done! Saved to {final_file}')
+
+
+if __name__ == '__main__':
+    main()
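+
+# Example usage (an illustrative sketch: the src checkpoint is the official
+# SwinV1 weight referenced in the configs above; the dst filename is a
+# placeholder, not a checkpoint shipped with this PR):
+#   python tools/model_converters/swinv1_to_mmdet.py \
+#       swin_large_patch4_window12_384_22k.pth swin_large_mmdet.pth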