From e590bfebf5eb9c68c4e7f8a297c31e3be5af3219 Mon Sep 17 00:00:00 2001 From: VVsssssk Date: Thu, 1 Sep 2022 19:47:23 +0800 Subject: [PATCH 01/16] add deploy.yaml --- .github/workflows/deploy.yaml | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 .github/workflows/deploy.yaml diff --git a/.github/workflows/deploy.yaml b/.github/workflows/deploy.yaml new file mode 100644 index 0000000000..63caacfa74 --- /dev/null +++ b/.github/workflows/deploy.yaml @@ -0,0 +1,27 @@ + +name: deploy + +on: push + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-n-publish: + runs-on: ubuntu-18.04 + if: startsWith(github.event.ref, 'refs/tags') + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.7 + uses: actions/setup-python@v2 + with: + python-version: 3.7 + - name: Install torch + run: pip install torch + - name: Build MMDet3D + run: python setup.py sdist + - name: Publish distribution to PyPI + run: | + pip install twine + twine upload dist/* -u __token__ -p ${{ secrets.pypi_password }} From f3a6f40d7792770c8861b604bb1ff661dd5cd383 Mon Sep 17 00:00:00 2001 From: VVsssssk Date: Thu, 8 Sep 2022 14:06:51 +0800 Subject: [PATCH 02/16] fix --- configs/_base_/models/point_rcnn.py | 35 +++-- .../point-rcnn_8xb2_kitti-3d-3class.py | 50 +++--- mmdet3d/datasets/transforms/loading.py | 5 +- mmdet3d/models/dense_heads/parta2_rpn_head.py | 9 +- mmdet3d/models/dense_heads/point_rpn_head.py | 145 ++++++++++++++---- mmdet3d/models/detectors/point_rcnn.py | 127 +++------------ mmdet3d/models/detectors/two_stage.py | 5 +- .../bbox_heads/point_rcnn_bbox_head.py | 18 ++- .../roi_heads/part_aggregation_roi_head.py | 16 +- .../models/roi_heads/point_rcnn_roi_head.py | 89 ++++++----- .../test_detectors/test_pointrcnn.py | 47 ++++++ 11 files changed, 301 insertions(+), 245 deletions(-) create mode 100644 tests/test_models/test_detectors/test_pointrcnn.py diff --git a/configs/_base_/models/point_rcnn.py 
b/configs/_base_/models/point_rcnn.py index 02a1414f7d..38219e9a02 100644 --- a/configs/_base_/models/point_rcnn.py +++ b/configs/_base_/models/point_rcnn.py @@ -1,5 +1,6 @@ model = dict( type='PointRCNN', + data_preprocessor=dict(type='Det3DDataPreprocessor'), backbone=dict( type='PointNet2SAMSG', in_channels=4, @@ -34,14 +35,14 @@ cls_linear_channels=(256, 256), reg_linear_channels=(256, 256)), cls_loss=dict( - type='FocalLoss', + type='mmdet.FocalLoss', use_sigmoid=True, reduction='sum', gamma=2.0, alpha=0.25, loss_weight=1.0), bbox_loss=dict( - type='SmoothL1Loss', + type='mmdet.SmoothL1Loss', beta=1.0 / 9.0, reduction='sum', loss_weight=1.0), @@ -55,12 +56,22 @@ 1.73]])), roi_head=dict( type='PointRCNNRoIHead', - point_roi_extractor=dict( + bbox_roi_extractor=dict( type='Single3DRoIPointExtractor', roi_layer=dict(type='RoIPointPool3d', num_sampled_points=512)), bbox_head=dict( type='PointRCNNBboxHead', num_classes=1, + loss_bbox=dict( + type='mmdet.SmoothL1Loss', + beta=1.0 / 9.0, + reduction='sum', + loss_weight=1.0), + loss_cls=dict( + type='mmdet.CrossEntropyLoss', + use_sigmoid=True, + reduction='sum', + loss_weight=1.0), pred_layer_cfg=dict( in_channels=512, cls_conv_channels=(256, 256), @@ -79,13 +90,13 @@ train_cfg=dict( pos_distance_thr=10.0, rpn=dict( - nms_cfg=dict( - use_rotate_nms=True, iou_thr=0.8, nms_pre=9000, nms_post=512), - score_thr=None), + rpn_proposal=dict( + use_rotate_nms=True,score_thr=None, iou_thr=0.8, nms_pre=9000, nms_post=512)), + rcnn=dict( assigner=[ dict( # for Car - type='MaxIoUAssigner', + type='Max3DIoUAssigner', iou_calculator=dict( type='BboxOverlaps3D', coordinate='lidar'), pos_iou_thr=0.55, @@ -94,7 +105,7 @@ ignore_iof_thr=-1, match_low_quality=False), dict( # for Pedestrian - type='MaxIoUAssigner', + type='Max3DIoUAssigner', iou_calculator=dict( type='BboxOverlaps3D', coordinate='lidar'), pos_iou_thr=0.55, @@ -103,7 +114,7 @@ ignore_iof_thr=-1, match_low_quality=False), dict( # for Cyclist - type='MaxIoUAssigner', + 
type='Max3DIoUAssigner', iou_calculator=dict( type='BboxOverlaps3D', coordinate='lidar'), pos_iou_thr=0.55, @@ -125,7 +136,7 @@ cls_neg_thr=0.25)), test_cfg=dict( rpn=dict( - nms_cfg=dict( - use_rotate_nms=True, iou_thr=0.85, nms_pre=9000, nms_post=512), - score_thr=None), + nms_cfg=dict( + use_rotate_nms=True, iou_thr=0.85, nms_pre=9000, nms_post=512, + score_thr=None)), rcnn=dict(use_rotate_nms=True, nms_thr=0.1, score_thr=0.1))) diff --git a/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py b/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py index 7dbf35f012..305f82ae74 100644 --- a/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py +++ b/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py @@ -7,6 +7,7 @@ dataset_type = 'KittiDataset' data_root = 'data/kitti/' class_names = ['Car', 'Pedestrian', 'Cyclist'] +metainfo = dict(CLASSES=class_names) point_cloud_range = [0, -40, -3, 70.4, 40, 1] input_modality = dict(use_lidar=True, use_camera=False) @@ -42,8 +43,9 @@ dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), dict(type='PointSample', num_points=16384, sample_range=40.0), dict(type='PointShuffle'), - dict(type='DefaultFormatBundle3D', class_names=class_names), - dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) + dict( + type='Pack3DDetInputs', + keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) ] test_pipeline = [ dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4), @@ -61,39 +63,23 @@ dict(type='RandomFlip3D'), dict( type='PointsRangeFilter', point_cloud_range=point_cloud_range), - dict(type='PointSample', num_points=16384, sample_range=40.0), - dict( - type='DefaultFormatBundle3D', - class_names=class_names, - with_label=False), - dict(type='Collect3D', keys=['points']) - ]) + dict(type='PointSample', num_points=16384, sample_range=40.0) + ]), + dict(type='Pack3DDetInputs', keys=['points']) ] +train_dataloader = dict(batch_size=2, + num_workers=2, + dataset=dict(type='RepeatDataset', 
+ times=2, + dataset=dict(pipeline=train_pipeline, + metainfo=metainfo))) +test_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo)) +val_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo)) -data = dict( - samples_per_gpu=2, - workers_per_gpu=2, - train=dict( - type='RepeatDataset', - times=2, - dataset=dict(pipeline=train_pipeline, classes=class_names)), - val=dict(pipeline=test_pipeline, classes=class_names), - test=dict(pipeline=test_pipeline, classes=class_names)) - -# optimizer lr = 0.001 # max learning rate -optimizer = dict(lr=lr, betas=(0.95, 0.85)) -# runtime settings -runner = dict(type='EpochBasedRunner', max_epochs=80) -evaluation = dict(interval=2) -# yapf:disable -log_config = dict( - interval=30, - hooks=[ - dict(type='TextLoggerHook'), - dict(type='TensorboardLoggerHook') - ]) -# yapf:enable +optim_wrapper = dict(optimizer=dict(lr=lr, betas=(0.95, 0.85))) +train_cfg = dict(by_epoch=True, max_epochs=80, val_interval=2) + # Default setting for scaling LR automatically # - `enable` means enable scaling LR automatically # or not by default. diff --git a/mmdet3d/datasets/transforms/loading.py b/mmdet3d/datasets/transforms/loading.py index 615c1d74f3..0a60bd111e 100644 --- a/mmdet3d/datasets/transforms/loading.py +++ b/mmdet3d/datasets/transforms/loading.py @@ -424,7 +424,10 @@ def __init__( self.load_dim = load_dim self.use_dim = use_dim self.file_client_args = file_client_args.copy() - self.file_client = None + if self.file_client_args is not None: + self.file_client = mmengine.FileClient(**self.file_client_args) + else: + self.file_client = None def _load_points(self, pts_filename: str) -> np.ndarray: """Private function to load point clouds data. 
diff --git a/mmdet3d/models/dense_heads/parta2_rpn_head.py b/mmdet3d/models/dense_heads/parta2_rpn_head.py index 8af3f256b7..a81682caff 100644 --- a/mmdet3d/models/dense_heads/parta2_rpn_head.py +++ b/mmdet3d/models/dense_heads/parta2_rpn_head.py @@ -183,7 +183,7 @@ def _predict_by_feat_single(self, result = self.class_agnostic_nms(mlvl_bboxes, mlvl_bboxes_for_nms, mlvl_max_scores, mlvl_label_pred, mlvl_cls_score, mlvl_dir_scores, - score_thr, cfg.nms_post, cfg, + score_thr, cfg, input_meta) return result @@ -275,7 +275,7 @@ def class_agnostic_nms(self, mlvl_bboxes: Tensor, mlvl_bboxes_for_nms: Tensor, mlvl_max_scores: Tensor, mlvl_label_pred: Tensor, mlvl_cls_score: Tensor, mlvl_dir_scores: Tensor, - score_thr: int, max_num: int, cfg: ConfigDict, + score_thr: int, cfg: ConfigDict, input_meta: dict) -> Dict: """Class agnostic nms for single batch. @@ -291,7 +291,6 @@ def class_agnostic_nms(self, mlvl_bboxes: Tensor, mlvl_dir_scores (torch.Tensor): Direction scores of Multi-level bbox. score_thr (int): Score threshold. - max_num (int): Max number of bboxes after nms. cfg (:obj:`ConfigDict`): Training or testing config. input_meta (dict): Contain pcd and img's meta info. 
@@ -339,9 +338,9 @@ def class_agnostic_nms(self, mlvl_bboxes: Tensor, scores = torch.cat(scores, dim=0) cls_scores = torch.cat(cls_scores, dim=0) labels = torch.cat(labels, dim=0) - if bboxes.shape[0] > max_num: + if bboxes.shape[0] > cfg.nms_post: _, inds = scores.sort(descending=True) - inds = inds[:max_num] + inds = inds[: cfg.nms_post] bboxes = bboxes[inds, :] labels = labels[inds] scores = scores[inds] diff --git a/mmdet3d/models/dense_heads/point_rpn_head.py b/mmdet3d/models/dense_heads/point_rpn_head.py index 3276830afe..6b0f0d1c40 100644 --- a/mmdet3d/models/dense_heads/point_rpn_head.py +++ b/mmdet3d/models/dense_heads/point_rpn_head.py @@ -2,7 +2,7 @@ import torch from mmengine.model import BaseModule from torch import nn as nn - +from torch import Tensor from mmdet3d.models.builder import build_loss from mmdet3d.models.layers import nms_bev, nms_normal_bev from mmdet3d.registry import MODELS, TASK_UTILS @@ -10,7 +10,10 @@ from mmdet3d.structures.bbox_3d import (DepthInstance3DBoxes, LiDARInstance3DBoxes) from mmdet.models.utils import multi_apply - +from mmengine.structures import InstanceData +from typing import Dict, Tuple, List +from mmdet3d.structures.det3d_data_sample import SampleList +from mmdet3d.utils.typing import InstanceList @MODELS.register_module() class PointRPNHead(BaseModule): @@ -102,7 +105,7 @@ def _get_reg_out_channels(self): # torch.cos(yaw) (1), torch.sin(yaw) (1) return self.bbox_coder.code_size - def forward(self, feat_dict): + def forward(self, feat_dict: dict) -> Tuple[list]: """Forward pass. Args: @@ -112,7 +115,7 @@ def forward(self, feat_dict): tuple[list[torch.Tensor]]: Predicted boxes and classification scores. 
""" - point_features = feat_dict['fp_features'] + point_features = feat_dict['features'] point_features = point_features.permute(0, 2, 1).contiguous() batch_size = point_features.shape[0] feat_cls = point_features.view(-1, point_features.shape[-1]) @@ -124,13 +127,13 @@ def forward(self, feat_dict): batch_size, -1, self._get_reg_out_channels()) return point_box_preds, point_cls_preds - def loss(self, - bbox_preds, - cls_preds, + def loss_by_feat(self, + bbox_preds: List[Tensor], + cls_preds: List[Tensor], points, - gt_bboxes_3d, - gt_labels_3d, - img_metas=None): + batch_gt_instances_3d, + batch_input_metas=None, + batch_gt_instances_ignore= None): """Compute loss. Args: @@ -147,7 +150,7 @@ def loss(self, Returns: dict: Losses of PointRCNN RPN module. """ - targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d) + targets = self.get_targets(points, batch_gt_instances_3d) (bbox_targets, mask_targets, positive_mask, negative_mask, box_loss_weights, point_targets) = targets @@ -169,7 +172,7 @@ def loss(self, return losses - def get_targets(self, points, gt_bboxes_3d, gt_labels_3d): + def get_targets(self, points, batch_gt_instances_3d): """Generate targets of PointRCNN RPN head. Args: @@ -181,6 +184,8 @@ def get_targets(self, points, gt_bboxes_3d, gt_labels_3d): Returns: tuple[torch.Tensor]: Targets of PointRCNN RPN head. """ + gt_labels_3d = [instances.labels_3d for instances in batch_gt_instances_3d] + gt_bboxes_3d = [instances.bboxes_3d for instances in batch_gt_instances_3d] # find empty example for index in range(len(gt_labels_3d)): if len(gt_labels_3d[index]) == 0: @@ -243,11 +248,12 @@ def get_targets_single(self, points, gt_bboxes_3d, gt_labels_3d): return (bbox_targets, mask_targets, positive_mask, negative_mask, point_targets) - def get_bboxes(self, + def predict_by_feat(self, points, bbox_preds, cls_preds, - input_metas, + batch_input_metas, + cfg, rescale=False): """Generate bboxes from RPN head predictions. 
@@ -273,16 +279,19 @@ def get_bboxes(self, object_class[b]) bbox_selected, score_selected, labels, cls_preds_selected = \ self.class_agnostic_nms(obj_scores[b], sem_scores[b], bbox3d, - points[b, ..., :3], input_metas[b]) - bbox = input_metas[b]['box_type_3d']( - bbox_selected.clone(), - box_dim=bbox_selected.shape[-1], - with_yaw=True) - results.append((bbox, score_selected, labels, cls_preds_selected)) + points[b, ..., :3], batch_input_metas[b],cfg.nms_cfg) + bbox_selected = batch_input_metas[b]['box_type_3d']( + bbox_selected, box_dim=bbox_selected.shape[-1]) + result = InstanceData() + result.bboxes_3d = bbox_selected + result.scores_3d = score_selected + result.labels_3d = labels + result.cls_preds = cls_preds_selected + results.append(result) return results def class_agnostic_nms(self, obj_scores, sem_scores, bbox, points, - input_meta): + input_meta,nms_cfg): """Class agnostic nms. Args: @@ -293,8 +302,6 @@ def class_agnostic_nms(self, obj_scores, sem_scores, bbox, points, Returns: tuple[torch.Tensor]: Bounding boxes, scores and labels. 
""" - nms_cfg = self.test_cfg.nms_cfg if not self.training \ - else self.train_cfg.nms_cfg if nms_cfg.use_rotate_nms: nms_func = nms_bev else: @@ -323,14 +330,14 @@ def class_agnostic_nms(self, obj_scores, sem_scores, bbox, points, bbox = bbox[nonempty_box_mask] - if self.test_cfg.score_thr is not None: - score_thr = self.test_cfg.score_thr + if nms_cfg.score_thr is not None: + score_thr = nms_cfg.score_thr keep = (obj_scores >= score_thr) obj_scores = obj_scores[keep] sem_scores = sem_scores[keep] bbox = bbox.tensor[keep] - if obj_scores.shape[0] > 0: + if bbox.tensor.shape[0] > 0: topk = min(nms_cfg.nms_pre, obj_scores.shape[0]) obj_scores_nms, indices = torch.topk(obj_scores, k=topk) bbox_for_nms = xywhr2xyxyr(bbox[indices].bev) @@ -343,12 +350,18 @@ def class_agnostic_nms(self, obj_scores, sem_scores, bbox, points, score_selected = obj_scores_nms[keep] cls_preds = sem_scores_nms[keep] labels = torch.argmax(cls_preds, -1) + if bbox_selected.shape[0] > nms_cfg.nms_post: + _, inds = score_selected.sort(descending=True) + inds = inds[: score_selected.nms_post] + bbox_selected = bbox_selected[inds, :] + labels = labels[inds] + score_selected = score_selected[inds] + cls_preds = cls_preds[inds,:] else: bbox_selected = bbox.tensor score_selected = obj_scores.new_zeros([0]) labels = obj_scores.new_zeros([0]) cls_preds = obj_scores.new_zeros([0, sem_scores.shape[-1]]) - return bbox_selected, score_selected, labels, cls_preds def _assign_targets_by_points_inside(self, bboxes_3d, points): @@ -379,3 +392,81 @@ def _assign_targets_by_points_inside(self, bboxes_3d, points): raise NotImplementedError('Unsupported bbox type!') return points_mask, assignment + + def predict(self, feats_dict: Dict, + batch_data_samples: SampleList) -> InstanceList: + """Perform forward propagation of the 3D detection head and predict + detection results on the features of the upstream network. + + Args: + feats_dict (dict): Contains features from the first stage. 
+ batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + samples. It usually includes information such as + `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`. + + Returns: + list[:obj:`InstanceData`]: Detection results of each sample + after the post process. + Each item usually contains following keys. + + - scores_3d (Tensor): Classification scores, has a shape + (num_instances, ) + - labels_3d (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes, + contains a tensor with shape (num_instances, C), where + C >= 7. + """ + batch_input_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] + stack_points = feats_dict.pop('stack_points') + bbox_preds, cls_preds = self(feats_dict) + proposal_cfg = self.test_cfg + + proposal_list = self.predict_by_feat( + stack_points, bbox_preds, cls_preds, cfg=proposal_cfg, batch_input_metas=batch_input_metas) + feats_dict['points_cls_preds'] = cls_preds + return proposal_list + + def loss_and_predict(self, + feats_dict: Dict, + batch_data_samples: SampleList, + proposal_cfg=None, + **kwargs): + """Perform forward propagation of the head, then calculate loss and + predictions from the features and data samples. + + Args: + feats_dict (dict): Contains features from the first stage. + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + samples. It usually includes information such as + `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`. + proposal_cfg (ConfigDict, optional): Proposal config. + + Returns: + tuple: the return value is a tuple contains: + + - losses: (dict[str, Tensor]): A dictionary of loss components. + - predictions (list[:obj:`InstanceData`]): Detection + results of each sample after the post process. 
+ """ + batch_gt_instances_3d = [] + batch_gt_instances_ignore = [] + batch_input_metas = [] + for data_sample in batch_data_samples: + batch_input_metas.append(data_sample.metainfo) + batch_gt_instances_3d.append(data_sample.gt_instances_3d) + batch_gt_instances_ignore.append( + data_sample.get('ignored_instances', None)) + stack_points = feats_dict.pop('stack_points') + bbox_preds, cls_preds = self(feats_dict) + + loss_inputs = (bbox_preds, cls_preds, stack_points) + (batch_gt_instances_3d, batch_input_metas, + batch_gt_instances_ignore) + losses = self.loss_by_feat(*loss_inputs) + + predictions = self.predict_by_feat( + stack_points, bbox_preds, cls_preds, batch_input_metas=batch_input_metas, cfg=proposal_cfg) + feats_dict['points_cls_preds'] = cls_preds + return losses, predictions \ No newline at end of file diff --git a/mmdet3d/models/detectors/point_rcnn.py b/mmdet3d/models/detectors/point_rcnn.py index c763e5e86b..967e2243b6 100644 --- a/mmdet3d/models/detectors/point_rcnn.py +++ b/mmdet3d/models/detectors/point_rcnn.py @@ -3,8 +3,7 @@ from mmdet3d.registry import MODELS from .two_stage import TwoStage3DDetector - - +from typing import Optional,Dict @MODELS.register_module() class PointRCNN(TwoStage3DDetector): r"""PointRCNN detector. 
@@ -23,14 +22,14 @@ class PointRCNN(TwoStage3DDetector): """ def __init__(self, - backbone, - neck=None, - rpn_head=None, - roi_head=None, - train_cfg=None, - test_cfg=None, - pretrained=None, - init_cfg=None): + backbone: dict, + neck: dict = None, + rpn_head: dict = None, + roi_head: dict = None, + train_cfg: dict = None, + test_cfg: dict = None, + init_cfg: dict = None, + data_preprocessor: Optional[dict] = None) -> Optional: super(PointRCNN, self).__init__( backbone=backbone, neck=neck, @@ -38,111 +37,25 @@ def __init__(self, roi_head=roi_head, train_cfg=train_cfg, test_cfg=test_cfg, - pretrained=pretrained, - init_cfg=init_cfg) + init_cfg=init_cfg, + data_preprocessor=data_preprocessor) - def extract_feat(self, points): + def extract_feat(self, batch_inputs_dict: Dict) -> Dict: """Directly extract features from the backbone+neck. Args: - points (torch.Tensor): Input points. + batch_inputs_dict (dict): The model input dict which include + 'points', 'imgs' keys. + + - points (list[torch.Tensor]): Point cloud of each sample. + - imgs (torch.Tensor, optional): Image of each sample. Returns: - dict: Features from the backbone+neck + dict: Features from the backbone+neck and raw points. """ + points = torch.stack(batch_inputs_dict['points']) x = self.backbone(points) if self.with_neck: x = self.neck(x) - return x - - def forward_train(self, points, input_metas, gt_bboxes_3d, gt_labels_3d): - """Forward of training. - - Args: - points (list[torch.Tensor]): Points of each batch. - input_metas (list[dict]): Meta information of each sample. - gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch. - gt_labels_3d (list[torch.Tensor]): gt class labels of each batch. - - Returns: - dict: Losses. 
- """ - losses = dict() - stack_points = torch.stack(points) - x = self.extract_feat(stack_points) - - # features for rcnn - backbone_feats = x['fp_features'].clone() - backbone_xyz = x['fp_xyz'].clone() - rcnn_feats = {'features': backbone_feats, 'points': backbone_xyz} - - bbox_preds, cls_preds = self.rpn_head(x) - - rpn_loss = self.rpn_head.loss( - bbox_preds=bbox_preds, - cls_preds=cls_preds, - points=points, - gt_bboxes_3d=gt_bboxes_3d, - gt_labels_3d=gt_labels_3d, - input_metas=input_metas) - losses.update(rpn_loss) - - bbox_list = self.rpn_head.get_bboxes(stack_points, bbox_preds, - cls_preds, input_metas) - proposal_list = [ - dict( - boxes_3d=bboxes, - scores_3d=scores, - labels_3d=labels, - cls_preds=preds_cls) - for bboxes, scores, labels, preds_cls in bbox_list - ] - rcnn_feats.update({'points_cls_preds': cls_preds}) - - roi_losses = self.roi_head.forward_train(rcnn_feats, input_metas, - proposal_list, gt_bboxes_3d, - gt_labels_3d) - losses.update(roi_losses) - - return losses - - def simple_test(self, points, img_metas, imgs=None, rescale=False): - """Forward of testing. - - Args: - points (list[torch.Tensor]): Points of each sample. - img_metas (list[dict]): Image metas. - imgs (list[torch.Tensor], optional): Images of each sample. - Defaults to None. - rescale (bool, optional): Whether to rescale results. - Defaults to False. - - Returns: - list: Predicted 3d boxes. 
- """ - stack_points = torch.stack(points) - - x = self.extract_feat(stack_points) - # features for rcnn - backbone_feats = x['fp_features'].clone() - backbone_xyz = x['fp_xyz'].clone() - rcnn_feats = {'features': backbone_feats, 'points': backbone_xyz} - bbox_preds, cls_preds = self.rpn_head(x) - rcnn_feats.update({'points_cls_preds': cls_preds}) - - bbox_list = self.rpn_head.get_bboxes( - stack_points, bbox_preds, cls_preds, img_metas, rescale=rescale) - - proposal_list = [ - dict( - boxes_3d=bboxes, - scores_3d=scores, - labels_3d=labels, - cls_preds=preds_cls) - for bboxes, scores, labels, preds_cls in bbox_list - ] - bbox_results = self.roi_head.simple_test(rcnn_feats, img_metas, - proposal_list) - - return bbox_results + return dict(features=x['fp_features'].clone(),points=x['fp_xyz'].clone(),stack_points=points) diff --git a/mmdet3d/models/detectors/two_stage.py b/mmdet3d/models/detectors/two_stage.py index 8817a355c1..17122fac65 100644 --- a/mmdet3d/models/detectors/two_stage.py +++ b/mmdet3d/models/detectors/two_stage.py @@ -100,8 +100,9 @@ def loss(self, batch_inputs_dict: dict, batch_data_samples: SampleList, keys = rpn_losses.keys() for key in keys: if 'loss' in key and 'rpn' not in key: - rpn_losses[f'rpn_{key}'] = rpn_losses.pop(key) - losses.update(rpn_losses) + losses[f'rpn_{key}'] = rpn_losses[key] + else: + losses[key] = rpn_losses[key] else: # TODO: Not support currently, should have a check at Fast R-CNN assert batch_data_samples[0].get('proposals', None) is not None diff --git a/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py b/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py index 2f74c78668..7207489700 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py +++ b/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py @@ -12,7 +12,7 @@ from mmdet3d.structures.bbox_3d import (LiDARInstance3DBoxes, rotation_3d_in_axis, xywhr2xyxyr) from mmdet.models.utils import multi_apply - +from mmengine.structures 
import InstanceData @MODELS.register_module() class PointRCNNBboxHead(BaseModule): @@ -449,12 +449,12 @@ def _get_target_single(self, pos_bboxes, pos_gt_bboxes, ious, cfg): return (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, bbox_weights) - def get_bboxes(self, + def get_results(self, rois, cls_score, bbox_pred, class_labels, - img_metas, + input_metas, cfg=None): """Generate bboxes from bbox head predictions. @@ -494,16 +494,18 @@ def get_bboxes(self, cur_rcnn_boxes3d = rcnn_boxes3d[roi_batch_id == batch_id] keep = self.multi_class_nms(cur_box_prob, cur_rcnn_boxes3d, cfg.score_thr, cfg.nms_thr, - img_metas[batch_id], + input_metas[batch_id], cfg.use_rotate_nms) selected_bboxes = cur_rcnn_boxes3d[keep] selected_label_preds = cur_class_labels[keep] selected_scores = cur_cls_score[keep] + results = InstanceData() + results.bboxes_3d = input_metas[batch_id]['box_type_3d']( + selected_bboxes, selected_bboxes.shape[-1]) + results.scores_3d = selected_scores + results.labels_3d = selected_label_preds - result_list.append( - (img_metas[batch_id]['box_type_3d'](selected_bboxes, - self.bbox_coder.code_size), - selected_scores, selected_label_preds)) + result_list.append(results) return result_list def multi_class_nms(self, diff --git a/mmdet3d/models/roi_heads/part_aggregation_roi_head.py b/mmdet3d/models/roi_heads/part_aggregation_roi_head.py index e26c3aba6c..aed8f2e4c5 100644 --- a/mmdet3d/models/roi_heads/part_aggregation_roi_head.py +++ b/mmdet3d/models/roi_heads/part_aggregation_roi_head.py @@ -91,7 +91,8 @@ def _bbox_forward_train(self, feats_dict: Dict, voxels_dict: Dict, def _assign_and_sample( self, proposal_list: InstanceList, - batch_gt_instances_3d: InstanceList) -> List[SamplingResult]: + batch_gt_instances_3d: InstanceList, + batch_gt_instances_ignore) -> List[SamplingResult]: """Assign and sample proposals for training. 
Args: @@ -112,11 +113,12 @@ def _assign_and_sample( cur_boxes = cur_proposal_list['bboxes_3d'] cur_labels_3d = cur_proposal_list['labels_3d'] cur_gt_instances_3d = batch_gt_instances_3d[batch_idx] + cur_gt_instances_ignore = batch_gt_instances_ignore[batch_idx] cur_gt_instances_3d.bboxes_3d = cur_gt_instances_3d.\ bboxes_3d.tensor - cur_gt_bboxes = batch_gt_instances_3d[batch_idx].bboxes_3d.to( + cur_gt_bboxes = cur_gt_instances_3d.bboxes_3d.to( cur_boxes.device) - cur_gt_labels = batch_gt_instances_3d[batch_idx].labels_3d + cur_gt_labels = cur_gt_instances_3d.labels_3d batch_num_gts = 0 # 0 is bg @@ -132,7 +134,8 @@ def _assign_and_sample( pred_per_cls = (cur_labels_3d == i) cur_assign_res = assigner.assign( cur_proposal_list[pred_per_cls], - cur_gt_instances_3d[gt_per_cls]) + cur_gt_instances_3d[gt_per_cls], + cur_gt_instances_ignore) # gather assign_results in different class into one result batch_num_gts += cur_assign_res.num_gts # gt inds (1-based) @@ -158,7 +161,7 @@ def _assign_and_sample( batch_gt_labels) else: # for single class assign_result = self.bbox_assigner.assign( - cur_proposal_list, cur_gt_instances_3d) + cur_proposal_list, cur_gt_instances_3d,cur_gt_instances_ignore) # sample boxes sampling_result = self.bbox_sampler.sample(assign_result, cur_boxes.tensor, @@ -342,7 +345,8 @@ def loss(self, feats_dict: Dict, rpn_results_list: InstanceList, losses.update(semantic_results.pop('loss_semantic')) sample_results = self._assign_and_sample(rpn_results_list, - batch_gt_instances_3d) + batch_gt_instances_3d, + batch_gt_instances_ignore) if self.with_bbox: feats_dict.update(semantic_results) bbox_results = self._bbox_forward_train(feats_dict, voxels_dict, diff --git a/mmdet3d/models/roi_heads/point_rcnn_roi_head.py b/mmdet3d/models/roi_heads/point_rcnn_roi_head.py index dcfb1b9135..ace2882121 100644 --- a/mmdet3d/models/roi_heads/point_rcnn_roi_head.py +++ b/mmdet3d/models/roi_heads/point_rcnn_roi_head.py @@ -14,7 +14,7 @@ class 
PointRCNNRoIHead(Base3DRoIHead): Args: bbox_head (dict): Config of bbox_head. - point_roi_extractor (dict): Config of RoI extractor. + bbox_roi_extractor (dict): Config of RoI extractor. train_cfg (dict): Train configs. test_cfg (dict): Test configs. depth_normalizer (float, optional): Normalize depth feature. @@ -24,33 +24,21 @@ class PointRCNNRoIHead(Base3DRoIHead): def __init__(self, bbox_head, - point_roi_extractor, + bbox_roi_extractor, train_cfg, test_cfg, depth_normalizer=70.0, - pretrained=None, init_cfg=None): super(PointRCNNRoIHead, self).__init__( bbox_head=bbox_head, + bbox_roi_extractor=bbox_roi_extractor, train_cfg=train_cfg, test_cfg=test_cfg, - pretrained=pretrained, init_cfg=init_cfg) self.depth_normalizer = depth_normalizer - if point_roi_extractor is not None: - self.point_roi_extractor = MODELS.build(point_roi_extractor) - self.init_assigner_sampler() - def init_bbox_head(self, bbox_head): - """Initialize box head. - - Args: - bbox_head (dict): Config dict of RoI Head. - """ - self.bbox_head = MODELS.build(bbox_head) - def init_mask_head(self): """Initialize maek head.""" pass @@ -68,8 +56,8 @@ def init_assigner_sampler(self): ] self.bbox_sampler = TASK_UTILS.build(self.train_cfg.sampler) - def forward_train(self, feats_dict, input_metas, proposal_list, - gt_bboxes_3d, gt_labels_3d): + def loss(self, feats_dict, rpn_results_list, + batch_data_samples, **kwargs) -> dict: """Training forward function of PointRCNNRoIHead. 
Args: @@ -94,9 +82,15 @@ def forward_train(self, feats_dict, input_metas, proposal_list, point_cls_preds = feats_dict['points_cls_preds'] sem_scores = point_cls_preds.sigmoid() point_scores = sem_scores.max(-1)[0] - - sample_results = self._assign_and_sample(proposal_list, gt_bboxes_3d, - gt_labels_3d) + batch_gt_instances_3d = [] + batch_gt_instances_ignore = [] + for data_sample in batch_data_samples: + batch_gt_instances_3d.append(data_sample.gt_instances_3d) + if 'ignored_instances' in data_sample: + batch_gt_instances_ignore.append(data_sample.ignored_instances) + else: + batch_gt_instances_ignore.append(None) + sample_results = self._assign_and_sample(rpn_results_list, batch_gt_instances_3d,batch_gt_instances_ignore) # concat the depth, semantic features and backbone features features = features.transpose(1, 2).contiguous() @@ -114,7 +108,12 @@ def forward_train(self, feats_dict, input_metas, proposal_list, return losses - def simple_test(self, feats_dict, img_metas, proposal_list, **kwargs): + def predict(self, + feats_dict, + rpn_results_list, + batch_data_samples, + rescale: bool = False, + **kwargs): """Simple testing forward function of PointRCNNRoIHead. Note: @@ -128,9 +127,11 @@ def simple_test(self, feats_dict, img_metas, proposal_list, **kwargs): Returns: dict: Bbox results of one frame. 
""" - rois = bbox3d2roi([res['boxes_3d'].tensor for res in proposal_list]) - labels_3d = [res['labels_3d'] for res in proposal_list] - + rois = bbox3d2roi([res['bboxes_3d'].tensor for res in rpn_results_list]) + labels_3d = [res['labels_3d'] for res in rpn_results_list] + batch_input_metas = [ + data_samples.metainfo for data_samples in batch_data_samples + ] features = feats_dict['features'] points = feats_dict['points'] point_cls_preds = feats_dict['points_cls_preds'] @@ -148,19 +149,15 @@ def simple_test(self, feats_dict, img_metas, proposal_list, **kwargs): batch_size = features.shape[0] bbox_results = self._bbox_forward(features, points, batch_size, rois) object_score = bbox_results['cls_score'].sigmoid() - bbox_list = self.bbox_head.get_bboxes( + bbox_list = self.bbox_head.get_results( rois, object_score, bbox_results['bbox_pred'], labels_3d, - img_metas, + batch_input_metas, cfg=self.test_cfg) - bbox_results = [ - bbox3d2result(bboxes, scores, labels) - for bboxes, scores, labels in bbox_list - ] - return bbox_results + return bbox_list def _bbox_forward_train(self, features, points, sampling_results): """Forward training function of roi_extractor and bbox_head. @@ -203,14 +200,14 @@ def _bbox_forward(self, features, points, batch_size, rois): dict: Contains predictions of bbox_head and features of roi_extractor. """ - pooled_point_feats = self.point_roi_extractor(features, points, + pooled_point_feats = self.bbox_roi_extractor(features, points, batch_size, rois) cls_score, bbox_pred = self.bbox_head(pooled_point_feats) bbox_results = dict(cls_score=cls_score, bbox_pred=bbox_pred) return bbox_results - def _assign_and_sample(self, proposal_list, gt_bboxes_3d, gt_labels_3d): + def _assign_and_sample(self, rpn_results_list, batch_gt_instances_3d,batch_gt_instances_ignore): """Assign and sample proposals for training. 
Args: @@ -225,12 +222,16 @@ def _assign_and_sample(self, proposal_list, gt_bboxes_3d, gt_labels_3d): """ sampling_results = [] # bbox assign - for batch_idx in range(len(proposal_list)): - cur_proposal_list = proposal_list[batch_idx] - cur_boxes = cur_proposal_list['boxes_3d'] + for batch_idx in range(len(rpn_results_list)): + cur_proposal_list = rpn_results_list[batch_idx] + cur_boxes = cur_proposal_list['bboxes_3d'] cur_labels_3d = cur_proposal_list['labels_3d'] - cur_gt_bboxes = gt_bboxes_3d[batch_idx].to(cur_boxes.device) - cur_gt_labels = gt_labels_3d[batch_idx] + cur_gt_instances_3d = batch_gt_instances_3d[batch_idx] + cur_gt_instances_3d.bboxes_3d = cur_gt_instances_3d.\ + bboxes_3d.tensor + cur_gt_instances_ignore = batch_gt_instances_ignore[batch_idx] + cur_gt_bboxes = cur_gt_instances_3d.bboxes_3d.to(cur_boxes.device) + cur_gt_labels = cur_gt_instances_3d.labels_3d batch_num_gts = 0 # 0 is bg batch_gt_indis = cur_gt_labels.new_full((len(cur_boxes), ), 0) @@ -244,9 +245,9 @@ def _assign_and_sample(self, proposal_list, gt_bboxes_3d, gt_labels_3d): gt_per_cls = (cur_gt_labels == i) pred_per_cls = (cur_labels_3d == i) cur_assign_res = assigner.assign( - cur_boxes.tensor[pred_per_cls], - cur_gt_bboxes.tensor[gt_per_cls], - gt_labels=cur_gt_labels[gt_per_cls]) + cur_proposal_list[pred_per_cls], + cur_gt_instances_3d[gt_per_cls], + cur_gt_instances_ignore) # gather assign_results in different class into one result batch_num_gts += cur_assign_res.num_gts # gt inds (1-based) @@ -272,14 +273,12 @@ def _assign_and_sample(self, proposal_list, gt_bboxes_3d, gt_labels_3d): batch_gt_labels) else: # for single class assign_result = self.bbox_assigner.assign( - cur_boxes.tensor, - cur_gt_bboxes.tensor, - gt_labels=cur_gt_labels) + cur_proposal_list, cur_gt_instances_3d, cur_gt_instances_ignore) # sample boxes sampling_result = self.bbox_sampler.sample(assign_result, cur_boxes.tensor, - cur_gt_bboxes.tensor, + cur_gt_bboxes, cur_gt_labels) 
sampling_results.append(sampling_result) return sampling_results diff --git a/tests/test_models/test_detectors/test_pointrcnn.py b/tests/test_models/test_detectors/test_pointrcnn.py new file mode 100644 index 0000000000..2e7c0cfe0c --- /dev/null +++ b/tests/test_models/test_detectors/test_pointrcnn.py @@ -0,0 +1,47 @@ +import unittest + +import torch +from mmengine import DefaultScope + +from mmdet3d.registry import MODELS +from tests.utils.model_utils import (_create_detector_inputs, + _get_detector_cfg, _setup_seed) + + +class TestPointRCNN(unittest.TestCase): + + def test_pointrcnn(self): + import mmdet3d.models + + assert hasattr(mmdet3d.models, 'PointRCNN') + DefaultScope.get_instance('test_pointrcnn', scope_name='mmdet3d') + _setup_seed(0) + pointrcnn_cfg = _get_detector_cfg( + 'point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py') + model = MODELS.build(pointrcnn_cfg) + num_gt_instance = 2 + packed_inputs = _create_detector_inputs( + num_points=10101, + num_gt_instance=num_gt_instance) + + if torch.cuda.is_available(): + model = model.cuda() + # test simple_test + with torch.no_grad(): + data = model.data_preprocessor(packed_inputs, True) + torch.cuda.empty_cache() + results = model.forward(**data, mode='predict') + self.assertEqual(len(results), 1) + self.assertIn('bboxes_3d', results[0].pred_instances_3d) + self.assertIn('scores_3d', results[0].pred_instances_3d) + self.assertIn('labels_3d', results[0].pred_instances_3d) + + # save the memory + with torch.no_grad(): + losses = model.forward(**data, mode='loss') + torch.cuda.empty_cache() + self.assertGreaterEqual(losses['rpn_bbox_loss'], 0) + self.assertGreaterEqual(losses['rpn_semantic_loss'], 0) + self.assertGreaterEqual(losses['loss_cls'], 0) + self.assertGreaterEqual(losses['loss_bbox'], 0) + self.assertGreaterEqual(losses['loss_corner'], 0) From 3b5d329d71512159b2a19bd9b1e7d3683856ba0d Mon Sep 17 00:00:00 2001 From: VVsssssk Date: Thu, 8 Sep 2022 19:59:00 +0800 Subject: [PATCH 03/16] fix --- 
configs/_base_/models/point_rcnn.py | 6 ++--- .../point-rcnn_8xb2_kitti-3d-3class.py | 2 +- tools/model_converters/pointrcnn_convert.py | 25 +++++++++++++++++++ 3 files changed, 29 insertions(+), 4 deletions(-) create mode 100644 tools/model_converters/pointrcnn_convert.py diff --git a/configs/_base_/models/point_rcnn.py b/configs/_base_/models/point_rcnn.py index 38219e9a02..f381990041 100644 --- a/configs/_base_/models/point_rcnn.py +++ b/configs/_base_/models/point_rcnn.py @@ -95,7 +95,7 @@ rcnn=dict( assigner=[ - dict( # for Car + dict( # for Pedestrian type='Max3DIoUAssigner', iou_calculator=dict( type='BboxOverlaps3D', coordinate='lidar'), @@ -104,7 +104,7 @@ min_pos_iou=0.55, ignore_iof_thr=-1, match_low_quality=False), - dict( # for Pedestrian + dict( # for Cyclist type='Max3DIoUAssigner', iou_calculator=dict( type='BboxOverlaps3D', coordinate='lidar'), @@ -113,7 +113,7 @@ min_pos_iou=0.55, ignore_iof_thr=-1, match_low_quality=False), - dict( # for Cyclist + dict( # for Car type='Max3DIoUAssigner', iou_calculator=dict( type='BboxOverlaps3D', coordinate='lidar'), diff --git a/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py b/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py index 305f82ae74..f4dce06436 100644 --- a/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py +++ b/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py @@ -6,7 +6,7 @@ # dataset settings dataset_type = 'KittiDataset' data_root = 'data/kitti/' -class_names = ['Car', 'Pedestrian', 'Cyclist'] +class_names = ['Pedestrian', 'Cyclist', 'Car'] metainfo = dict(CLASSES=class_names) point_cloud_range = [0, -40, -3, 70.4, 40, 1] input_modality = dict(use_lidar=True, use_camera=False) diff --git a/tools/model_converters/pointrcnn_convert.py b/tools/model_converters/pointrcnn_convert.py new file mode 100644 index 0000000000..7edda4c4f5 --- /dev/null +++ b/tools/model_converters/pointrcnn_convert.py @@ -0,0 +1,25 @@ +import torch +import mmengine + +mm_path = 
'/home/PJLAB/shenkun/openmmlab-refactor/mmdetection3d/checkpoints/point_rcnn_2x8_kitti-3d-3classes_20211208_151344.pth' +pc_path = '/home/PJLAB/shenkun/workspace/OpenPCDet/checkpoint/pointrcnn_7870.pth' + +def main(): + new_dict = dict() + ori = torch.load(mm_path) + mm_dict = torch.load(mm_path)['state_dict'] + pc_dict = torch.load(pc_path)['model_state'] + pc_dict.pop('global_step') + for i in range(len(mm_dict.keys())): + mm_name = list(mm_dict.keys())[i] + if 'backbone' in mm_name and 'conv' in mm_name and 'bias' in mm_name: + continue + else: + new_dict[mm_name] = mm_dict[mm_name] + for i in range(len(new_dict.keys())): + new_dict[list(new_dict.keys())[i]] = pc_dict[list(pc_dict.keys())[i]] + ori['state_dict'] = new_dict + torch.save(ori,'new_pointrcnn.pth') + +if __name__ == '__main__': + main() From 387663b5143e2c2eae68a5ba39ddb99c08afb58b Mon Sep 17 00:00:00 2001 From: VVsssssk Date: Fri, 9 Sep 2022 11:47:53 +0800 Subject: [PATCH 04/16] fic --- mmdet3d/models/backbones/pointnet2_sa_msg.py | 2 +- mmdet3d/models/dense_heads/point_rpn_head.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/mmdet3d/models/backbones/pointnet2_sa_msg.py b/mmdet3d/models/backbones/pointnet2_sa_msg.py index 18bfae7695..5e88ee01a2 100644 --- a/mmdet3d/models/backbones/pointnet2_sa_msg.py +++ b/mmdet3d/models/backbones/pointnet2_sa_msg.py @@ -105,7 +105,7 @@ def __init__(self, dilated_group=dilated_group[sa_index], norm_cfg=norm_cfg, cfg=sa_cfg, - bias=True)) + bias=False)) skip_channel_list.append(sa_out_channel) cur_aggregation_channel = aggregation_channels[sa_index] diff --git a/mmdet3d/models/dense_heads/point_rpn_head.py b/mmdet3d/models/dense_heads/point_rpn_head.py index 6b0f0d1c40..b1c6ef7bb9 100644 --- a/mmdet3d/models/dense_heads/point_rpn_head.py +++ b/mmdet3d/models/dense_heads/point_rpn_head.py @@ -3,7 +3,6 @@ from mmengine.model import BaseModule from torch import nn as nn from torch import Tensor -from mmdet3d.models.builder import build_loss 
from mmdet3d.models.layers import nms_bev, nms_normal_bev from mmdet3d.registry import MODELS, TASK_UTILS from mmdet3d.structures import xywhr2xyxyr @@ -53,8 +52,8 @@ def __init__(self, self.enlarge_width = enlarge_width # build loss function - self.bbox_loss = build_loss(bbox_loss) - self.cls_loss = build_loss(cls_loss) + self.bbox_loss = MODELS.build(bbox_loss) + self.cls_loss = MODELS.build(cls_loss) # build box coder self.bbox_coder = TASK_UTILS.build(bbox_coder) From c36d77963d3c51f9d4a67e1bbbc955a5c3b805e5 Mon Sep 17 00:00:00 2001 From: VVsssssk Date: Fri, 9 Sep 2022 11:59:37 +0800 Subject: [PATCH 05/16] fix --- mmdet3d/models/dense_heads/ssd_3d_head.py | 5 ++--- mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py | 5 ++--- .../models/roi_heads/mask_heads/pointwise_semantic_head.py | 5 ++--- 3 files changed, 6 insertions(+), 9 deletions(-) diff --git a/mmdet3d/models/dense_heads/ssd_3d_head.py b/mmdet3d/models/dense_heads/ssd_3d_head.py index 15ffa216ae..8069c9bccf 100644 --- a/mmdet3d/models/dense_heads/ssd_3d_head.py +++ b/mmdet3d/models/dense_heads/ssd_3d_head.py @@ -14,7 +14,6 @@ LiDARInstance3DBoxes, rotation_3d_in_axis) from mmdet.models.utils import multi_apply -from ..builder import build_loss from .vote_head import VoteHead @@ -76,8 +75,8 @@ def __init__(self, size_res_loss=size_res_loss, semantic_loss=None, init_cfg=init_cfg) - self.corner_loss = build_loss(corner_loss) - self.vote_loss = build_loss(vote_loss) + self.corner_loss = MODELS.build(corner_loss) + self.vote_loss = MODELS.build(vote_loss) self.num_candidates = vote_module_cfg['num_points'] def _get_cls_out_channels(self) -> int: diff --git a/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py b/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py index 5ac565a6bc..0bc7d4f5a4 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py +++ b/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py @@ -21,7 +21,6 @@ from mmengine.model import BaseModule from torch import nn 
as nn -from mmdet3d.models.builder import build_loss from mmdet3d.models.layers import nms_bev, nms_normal_bev from mmdet3d.registry import MODELS, TASK_UTILS from mmdet3d.structures.bbox_3d import (LiDARInstance3DBoxes, @@ -88,8 +87,8 @@ def __init__(self, self.num_classes = num_classes self.with_corner_loss = with_corner_loss self.bbox_coder = TASK_UTILS.build(bbox_coder) - self.loss_bbox = build_loss(loss_bbox) - self.loss_cls = build_loss(loss_cls) + self.loss_bbox = MODELS.build(loss_bbox) + self.loss_cls = MODELS.build(loss_cls) self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) assert down_conv_channels[-1] == shared_fc_channels[0] diff --git a/mmdet3d/models/roi_heads/mask_heads/pointwise_semantic_head.py b/mmdet3d/models/roi_heads/mask_heads/pointwise_semantic_head.py index c580713910..73622202df 100644 --- a/mmdet3d/models/roi_heads/mask_heads/pointwise_semantic_head.py +++ b/mmdet3d/models/roi_heads/mask_heads/pointwise_semantic_head.py @@ -4,7 +4,6 @@ from torch import nn as nn from torch.nn import functional as F -from mmdet3d.models.builder import build_loss from mmdet3d.registry import MODELS from mmdet3d.structures.bbox_3d import rotation_3d_in_axis from mmdet3d.utils import InstanceList @@ -50,8 +49,8 @@ def __init__(self, self.seg_cls_layer = nn.Linear(in_channels, 1, bias=True) self.seg_reg_layer = nn.Linear(in_channels, 3, bias=True) - self.loss_seg = build_loss(loss_seg) - self.loss_part = build_loss(loss_part) + self.loss_seg = MODELS.build(loss_seg) + self.loss_part = MODELS.build(loss_part) def forward(self, x): """Forward pass. 
From 8dec36d9c135ee0064d77452f30d05f37af8988a Mon Sep 17 00:00:00 2001 From: VVsssssk Date: Fri, 9 Sep 2022 12:00:51 +0800 Subject: [PATCH 06/16] fix --- mmdet3d/datasets/transforms/transforms_3d.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/mmdet3d/datasets/transforms/transforms_3d.py b/mmdet3d/datasets/transforms/transforms_3d.py index 03b0933b2d..a61296c999 100644 --- a/mmdet3d/datasets/transforms/transforms_3d.py +++ b/mmdet3d/datasets/transforms/transforms_3d.py @@ -367,11 +367,10 @@ def transform(self, input_dict: dict) -> dict: gt_bboxes_3d = input_dict['gt_bboxes_3d'] gt_labels_3d = input_dict['gt_labels_3d'] - if self.use_ground_plane and 'plane' in input_dict['ann_info']: - ground_plane = input_dict['plane'] + if self.use_ground_plane: + ground_plane = input_dict.get('plane', None) assert ground_plane is not None, '`use_ground_plane` is True ' \ 'but find plane is None' - input_dict['plane'] = ground_plane else: ground_plane = None # change to float for blending operation From 1b585307490ddcb1dd22f2db38db207296f87872 Mon Sep 17 00:00:00 2001 From: VVsssssk Date: Fri, 9 Sep 2022 13:55:55 +0800 Subject: [PATCH 07/16] fix --- configs/_base_/models/point_rcnn.py | 16 ++- .../point-rcnn_8xb2_kitti-3d-3class.py | 13 +- .../models/dense_heads/base_3d_dense_head.py | 2 +- mmdet3d/models/dense_heads/parta2_rpn_head.py | 5 +- mmdet3d/models/dense_heads/point_rpn_head.py | 130 +++++++++++------- mmdet3d/models/detectors/point_rcnn.py | 12 +- .../roi_heads/bbox_heads/parta2_bbox_head.py | 15 +- .../bbox_heads/point_rcnn_bbox_head.py | 98 +++++++------ .../roi_heads/part_aggregation_roi_head.py | 13 +- .../models/roi_heads/point_rcnn_roi_head.py | 80 ++++++----- mmdet3d/utils/typing.py | 2 + .../test_detectors/test_pointrcnn.py | 3 +- 12 files changed, 239 insertions(+), 150 deletions(-) diff --git a/configs/_base_/models/point_rcnn.py b/configs/_base_/models/point_rcnn.py index f381990041..c23a78b55d 100644 --- 
a/configs/_base_/models/point_rcnn.py +++ b/configs/_base_/models/point_rcnn.py @@ -91,8 +91,11 @@ pos_distance_thr=10.0, rpn=dict( rpn_proposal=dict( - use_rotate_nms=True,score_thr=None, iou_thr=0.8, nms_pre=9000, nms_post=512)), - + use_rotate_nms=True, + score_thr=None, + iou_thr=0.8, + nms_pre=9000, + nms_post=512)), rcnn=dict( assigner=[ dict( # for Pedestrian @@ -136,7 +139,10 @@ cls_neg_thr=0.25)), test_cfg=dict( rpn=dict( - nms_cfg=dict( - use_rotate_nms=True, iou_thr=0.85, nms_pre=9000, nms_post=512, - score_thr=None)), + nms_cfg=dict( + use_rotate_nms=True, + iou_thr=0.85, + nms_pre=9000, + nms_post=512, + score_thr=None)), rcnn=dict(use_rotate_nms=True, nms_thr=0.1, score_thr=0.1))) diff --git a/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py b/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py index f4dce06436..51e31b5e8f 100644 --- a/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py +++ b/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py @@ -67,12 +67,13 @@ ]), dict(type='Pack3DDetInputs', keys=['points']) ] -train_dataloader = dict(batch_size=2, - num_workers=2, - dataset=dict(type='RepeatDataset', - times=2, - dataset=dict(pipeline=train_pipeline, - metainfo=metainfo))) +train_dataloader = dict( + batch_size=2, + num_workers=2, + dataset=dict( + type='RepeatDataset', + times=2, + dataset=dict(pipeline=train_pipeline, metainfo=metainfo))) test_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo)) val_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo)) diff --git a/mmdet3d/models/dense_heads/base_3d_dense_head.py b/mmdet3d/models/dense_heads/base_3d_dense_head.py index 658c951331..2d988bdadc 100644 --- a/mmdet3d/models/dense_heads/base_3d_dense_head.py +++ b/mmdet3d/models/dense_heads/base_3d_dense_head.py @@ -204,7 +204,7 @@ def predict_by_feat(self, score_factors (list[Tensor], optional): Score factor for all scale level, each is a 4D-tensor, has shape (batch_size, num_priors * 1, H, W). 
Defaults to None. - batch_input_metas (list[dict], Optional): Batch image meta info. + batch_input_metas (list[dict], Optional): Batch inputs meta info. Defaults to None. cfg (ConfigDict, optional): Test / postprocessing configuration, if None, test_cfg would be used. diff --git a/mmdet3d/models/dense_heads/parta2_rpn_head.py b/mmdet3d/models/dense_heads/parta2_rpn_head.py index a81682caff..94557b6767 100644 --- a/mmdet3d/models/dense_heads/parta2_rpn_head.py +++ b/mmdet3d/models/dense_heads/parta2_rpn_head.py @@ -183,8 +183,7 @@ def _predict_by_feat_single(self, result = self.class_agnostic_nms(mlvl_bboxes, mlvl_bboxes_for_nms, mlvl_max_scores, mlvl_label_pred, mlvl_cls_score, mlvl_dir_scores, - score_thr, cfg, - input_meta) + score_thr, cfg, input_meta) return result def loss_and_predict(self, @@ -340,7 +339,7 @@ def class_agnostic_nms(self, mlvl_bboxes: Tensor, labels = torch.cat(labels, dim=0) if bboxes.shape[0] > cfg.nms_post: _, inds = scores.sort(descending=True) - inds = inds[: cfg.nms_post] + inds = inds[:cfg.nms_post] bboxes = bboxes[inds, :] labels = labels[inds] scores = scores[inds] diff --git a/mmdet3d/models/dense_heads/point_rpn_head.py b/mmdet3d/models/dense_heads/point_rpn_head.py index b1c6ef7bb9..be7d7059b1 100644 --- a/mmdet3d/models/dense_heads/point_rpn_head.py +++ b/mmdet3d/models/dense_heads/point_rpn_head.py @@ -1,18 +1,21 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Dict, List, Tuple + import torch from mmengine.model import BaseModule -from torch import nn as nn +from mmengine.structures import InstanceData from torch import Tensor +from torch import nn as nn + from mmdet3d.models.layers import nms_bev, nms_normal_bev from mmdet3d.registry import MODELS, TASK_UTILS from mmdet3d.structures import xywhr2xyxyr from mmdet3d.structures.bbox_3d import (DepthInstance3DBoxes, LiDARInstance3DBoxes) -from mmdet.models.utils import multi_apply -from mmengine.structures import InstanceData -from typing import Dict, Tuple, List from mmdet3d.structures.det3d_data_sample import SampleList from mmdet3d.utils.typing import InstanceList +from mmdet.models.utils import multi_apply + @MODELS.register_module() class PointRPNHead(BaseModule): @@ -36,15 +39,15 @@ class PointRPNHead(BaseModule): """ def __init__(self, - num_classes, - train_cfg, - test_cfg, - pred_layer_cfg=None, - enlarge_width=0.1, - cls_loss=None, - bbox_loss=None, - bbox_coder=None, - init_cfg=None): + num_classes: int, + train_cfg: dict, + test_cfg: dict, + pred_layer_cfg: dict = None, + enlarge_width: float = 0.1, + cls_loss: dict = None, + bbox_loss: dict = None, + bbox_coder: dict = None, + init_cfg: dict = None) -> None: super().__init__(init_cfg=init_cfg) self.num_classes = num_classes self.train_cfg = train_cfg @@ -127,12 +130,12 @@ def forward(self, feat_dict: dict) -> Tuple[list]: return point_box_preds, point_cls_preds def loss_by_feat(self, - bbox_preds: List[Tensor], - cls_preds: List[Tensor], - points, - batch_gt_instances_3d, - batch_input_metas=None, - batch_gt_instances_ignore= None): + bbox_preds: List[Tensor], + cls_preds: List[Tensor], + points: List[Tensor], + batch_gt_instances_3d: InstanceList, + batch_input_metas: List[dict] = None, + batch_gt_instances_ignore: InstanceList = None) -> Dict: """Compute loss. Args: @@ -140,10 +143,13 @@ def loss_by_feat(self, cls_preds (dict): Classification from forward of PointRCNN RPN_Head. 
points (list[torch.Tensor]): Input points. - gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth - bboxes of each sample. - gt_labels_3d (list[torch.Tensor]): Labels of each sample. - img_metas (list[dict], Optional): Contain pcd and img's meta info. + batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of + gt_instances. It usually includes ``bboxes_3d`` and + ``labels_3d`` attributes. + batch_input_metas (list[dict]): Contain pcd and img's meta info. + batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): + Batch of gt_instances_ignore. It includes ``bboxes`` attribute + data that is ignored during training and testing. Defaults to None. Returns: @@ -183,8 +189,12 @@ def get_targets(self, points, batch_gt_instances_3d): Returns: tuple[torch.Tensor]: Targets of PointRCNN RPN head. """ - gt_labels_3d = [instances.labels_3d for instances in batch_gt_instances_3d] - gt_bboxes_3d = [instances.bboxes_3d for instances in batch_gt_instances_3d] + gt_labels_3d = [ + instances.labels_3d for instances in batch_gt_instances_3d + ] + gt_bboxes_3d = [ + instances.bboxes_3d for instances in batch_gt_instances_3d + ] # find empty example for index in range(len(gt_labels_3d)): if len(gt_labels_3d[index]) == 0: @@ -247,25 +257,32 @@ def get_targets_single(self, points, gt_bboxes_3d, gt_labels_3d): return (bbox_targets, mask_targets, positive_mask, negative_mask, point_targets) - def predict_by_feat(self, - points, - bbox_preds, - cls_preds, - batch_input_metas, - cfg, - rescale=False): + def predict_by_feat(self, points: Tensor, bbox_preds: List[Tensor], + cls_preds: List[Tensor], batch_input_metas: List[dict], + cfg: Dict) -> InstanceList: """Generate bboxes from RPN head predictions. Args: points (torch.Tensor): Input points. - bbox_preds (dict): Regression predictions from PointRCNN head. - cls_preds (dict): Class scores predictions from PointRCNN head. - input_metas (list[dict]): Point cloud and image's meta info. 
- rescale (bool, optional): Whether to rescale bboxes. - Defaults to False. + bbox_preds (list): Regression predictions from PointRCNN head. + cls_preds (list): Class scores predictions from PointRCNN head. + batch_input_metas (list[dict]): Batch inputs meta info. + cfg (ConfigDict, optional): Test / postprocessing + configuration. Returns: - list[tuple[torch.Tensor]]: Bounding boxes, scores and labels. + list[:obj:`InstanceData`]: Detection results of each sample + after the post process. + Each item usually contains following keys. + + - scores_3d (Tensor): Classification scores, has a shape + (num_instances, ) + - labels_3d (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes, + contains a tensor with shape (num_instances, C), where + C >= 7. + - cls_preds (torch.Tensor): Class score of each bbox. """ sem_scores = cls_preds.sigmoid() obj_scores = sem_scores.max(-1)[0] @@ -278,7 +295,9 @@ def predict_by_feat(self, object_class[b]) bbox_selected, score_selected, labels, cls_preds_selected = \ self.class_agnostic_nms(obj_scores[b], sem_scores[b], bbox3d, - points[b, ..., :3], batch_input_metas[b],cfg.nms_cfg) + points[b, ..., :3], + batch_input_metas[b], + cfg.nms_cfg) bbox_selected = batch_input_metas[b]['box_type_3d']( bbox_selected, box_dim=bbox_selected.shape[-1]) result = InstanceData() @@ -289,14 +308,18 @@ def predict_by_feat(self, results.append(result) return results - def class_agnostic_nms(self, obj_scores, sem_scores, bbox, points, - input_meta,nms_cfg): + def class_agnostic_nms(self, obj_scores: Tensor, sem_scores: Tensor, + bbox: Tensor, points: Tensor, input_meta: Dict, + nms_cfg: Dict) -> Tuple[Tensor]: """Class agnostic nms. Args: obj_scores (torch.Tensor): Objectness score of bounding boxes. sem_scores (torch.Tensor): Semantic class score of bounding boxes. bbox (torch.Tensor): Predicted bounding boxes. + points (torch.Tensor): Input points. 
+ input_meta (dict): Contain pcd and img's meta info. + nms_cfg (dict): NMS config dict. Returns: tuple[torch.Tensor]: Bounding boxes, scores and labels. @@ -351,11 +374,11 @@ def class_agnostic_nms(self, obj_scores, sem_scores, bbox, points, labels = torch.argmax(cls_preds, -1) if bbox_selected.shape[0] > nms_cfg.nms_post: _, inds = score_selected.sort(descending=True) - inds = inds[: score_selected.nms_post] + inds = inds[:nms_cfg.nms_post] bbox_selected = bbox_selected[inds, :] labels = labels[inds] score_selected = score_selected[inds] - cls_preds = cls_preds[inds,:] + cls_preds = cls_preds[inds, :] else: bbox_selected = bbox.tensor score_selected = obj_scores.new_zeros([0]) @@ -424,7 +447,11 @@ def predict(self, feats_dict: Dict, proposal_cfg = self.test_cfg proposal_list = self.predict_by_feat( - stack_points, bbox_preds, cls_preds, cfg=proposal_cfg, batch_input_metas=batch_input_metas) + stack_points, + bbox_preds, + cls_preds, + cfg=proposal_cfg, + batch_input_metas=batch_input_metas) feats_dict['points_cls_preds'] = cls_preds return proposal_list @@ -432,7 +459,7 @@ def loss_and_predict(self, feats_dict: Dict, batch_data_samples: SampleList, proposal_cfg=None, - **kwargs): + **kwargs) -> Tuple[dict, InstanceList]: """Perform forward propagation of the head, then calculate loss and predictions from the features and data samples. 
@@ -461,11 +488,16 @@ def loss_and_predict(self, stack_points = feats_dict.pop('stack_points') bbox_preds, cls_preds = self(feats_dict) - loss_inputs = (bbox_preds, cls_preds, stack_points) + (batch_gt_instances_3d, batch_input_metas, - batch_gt_instances_ignore) + loss_inputs = (bbox_preds, cls_preds, stack_points) + ( + batch_gt_instances_3d, batch_input_metas, + batch_gt_instances_ignore) losses = self.loss_by_feat(*loss_inputs) predictions = self.predict_by_feat( - stack_points, bbox_preds, cls_preds, batch_input_metas=batch_input_metas, cfg=proposal_cfg) + stack_points, + bbox_preds, + cls_preds, + batch_input_metas=batch_input_metas, + cfg=proposal_cfg) feats_dict['points_cls_preds'] = cls_preds - return losses, predictions \ No newline at end of file + return losses, predictions diff --git a/mmdet3d/models/detectors/point_rcnn.py b/mmdet3d/models/detectors/point_rcnn.py index 967e2243b6..dd014a593b 100644 --- a/mmdet3d/models/detectors/point_rcnn.py +++ b/mmdet3d/models/detectors/point_rcnn.py @@ -1,9 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Optional + import torch from mmdet3d.registry import MODELS from .two_stage import TwoStage3DDetector -from typing import Optional,Dict + + @MODELS.register_module() class PointRCNN(TwoStage3DDetector): r"""PointRCNN detector. @@ -38,7 +41,7 @@ def __init__(self, train_cfg=train_cfg, test_cfg=test_cfg, init_cfg=init_cfg, - data_preprocessor=data_preprocessor) + data_preprocessor=data_preprocessor) def extract_feat(self, batch_inputs_dict: Dict) -> Dict: """Directly extract features from the backbone+neck. 
@@ -58,4 +61,7 @@ def extract_feat(self, batch_inputs_dict: Dict) -> Dict: if self.with_neck: x = self.neck(x) - return dict(features=x['fp_features'].clone(),points=x['fp_xyz'].clone(),stack_points=points) + return dict( + features=x['fp_features'].clone(), + points=x['fp_xyz'].clone(), + stack_points=points) diff --git a/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py b/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py index 0bc7d4f5a4..9d2d17bef5 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py +++ b/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py @@ -10,6 +10,7 @@ from mmdet3d.models import make_sparse_convmodule from mmdet3d.models.layers.spconv import IS_SPCONV2_AVAILABLE +from mmdet3d.utils.typing import InstanceList from mmdet.models.utils import multi_apply if IS_SPCONV2_AVAILABLE: @@ -514,7 +515,7 @@ def get_results(self, class_labels: Tensor, class_pred: Tensor, input_metas: List[dict], - cfg: dict = None) -> List: + cfg: dict = None) -> InstanceList: """Generate bboxes from bbox head predictions. Args: @@ -527,7 +528,17 @@ def get_results(self, cfg (:obj:`ConfigDict`): Testing config. Returns: - list[tuple]: Decoded bbox, scores and labels after nms. + list[:obj:`InstanceData`]: Detection results of each sample + after the post process. + Each item usually contains following keys. + + - scores_3d (Tensor): Classification scores, has a shape + (num_instances, ) + - labels_3d (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes, + contains a tensor with shape (num_instances, C), where + C >= 7. 
""" roi_batch_id = rois[..., 0] roi_boxes = rois[..., 1:] # boxes without batch id diff --git a/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py b/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py index 7207489700..b52c58f1b6 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py +++ b/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py @@ -1,9 +1,13 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Tuple + import numpy as np import torch from mmcv.cnn import ConvModule from mmcv.cnn.bricks import build_conv_layer from mmengine.model import BaseModule, normal_init +from mmengine.structures import InstanceData +from torch import Tensor from torch import nn as nn from mmdet3d.models.layers import nms_bev, nms_normal_bev @@ -11,8 +15,9 @@ from mmdet3d.registry import MODELS, TASK_UTILS from mmdet3d.structures.bbox_3d import (LiDARInstance3DBoxes, rotation_3d_in_axis, xywhr2xyxyr) +from mmdet3d.utils.typing import InstanceList from mmdet.models.utils import multi_apply -from mmengine.structures import InstanceData + @MODELS.register_module() class PointRCNNBboxHead(BaseModule): @@ -61,34 +66,35 @@ class PointRCNNBboxHead(BaseModule): init_cfg (dict, optional): Config of initialization. Defaults to None. 
""" - def __init__( - self, - num_classes, - in_channels, - mlp_channels, - pred_layer_cfg=None, - num_points=(128, 32, -1), - radius=(0.2, 0.4, 100), - num_samples=(64, 64, 64), - sa_channels=((128, 128, 128), (128, 128, 256), (256, 256, 512)), - bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'), - sa_cfg=dict(type='PointSAModule', pool_mod='max', use_xyz=True), - conv_cfg=dict(type='Conv1d'), - norm_cfg=dict(type='BN1d'), - act_cfg=dict(type='ReLU'), - bias='auto', - loss_bbox=dict( - type='SmoothL1Loss', - beta=1.0 / 9.0, - reduction='sum', - loss_weight=1.0), - loss_cls=dict( - type='CrossEntropyLoss', - use_sigmoid=True, - reduction='sum', - loss_weight=1.0), - with_corner_loss=True, - init_cfg=None): + def __init__(self, + num_classes: dict, + in_channels: dict, + mlp_channels: dict, + pred_layer_cfg: dict = None, + num_points: dict = (128, 32, -1), + radius: dict = (0.2, 0.4, 100), + num_samples: dict = (64, 64, 64), + sa_channels: dict = ((128, 128, 128), (128, 128, 256), + (256, 256, 512)), + bbox_coder: dict = dict(type='DeltaXYZWLHRBBoxCoder'), + sa_cfg: dict = dict( + type='PointSAModule', pool_mod='max', use_xyz=True), + conv_cfg: dict = dict(type='Conv1d'), + norm_cfg: dict = dict(type='BN1d'), + act_cfg: dict = dict(type='ReLU'), + bias: str = 'auto', + loss_bbox: dict = dict( + type='SmoothL1Loss', + beta=1.0 / 9.0, + reduction='sum', + loss_weight=1.0), + loss_cls: dict = dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='sum', + loss_weight=1.0), + with_corner_loss: bool = True, + init_cfg: dict = None) -> None: super(PointRCNNBboxHead, self).__init__(init_cfg=init_cfg) self.num_classes = num_classes self.num_sa = len(sa_channels) @@ -203,7 +209,7 @@ def init_weights(self): nn.init.constant_(m.bias, 0) normal_init(self.conv_reg.weight, mean=0, std=0.001) - def forward(self, feats): + def forward(self, feats: Tensor) -> Tuple[Tensor]: """Forward pass. 
Args: @@ -239,8 +245,10 @@ def forward(self, feats): rcnn_reg = rcnn_reg.transpose(1, 2).contiguous().squeeze(dim=1) return rcnn_cls, rcnn_reg - def loss(self, cls_score, bbox_pred, rois, labels, bbox_targets, - pos_gt_bboxes, reg_mask, label_weights, bbox_weights): + def loss(self, cls_score: Tensor, bbox_pred: Tensor, rois: Tensor, + labels: Tensor, bbox_targets: Tensor, pos_gt_bboxes: Tensor, + reg_mask: Tensor, label_weights: Tensor, + bbox_weights: Tensor) -> Dict: """Computing losses. Args: @@ -450,12 +458,12 @@ def _get_target_single(self, pos_bboxes, pos_gt_bboxes, ious, cfg): bbox_weights) def get_results(self, - rois, - cls_score, - bbox_pred, - class_labels, - input_metas, - cfg=None): + rois: Tensor, + cls_score: Tensor, + bbox_pred: Tensor, + class_labels: Tensor, + input_metas: List[dict], + cfg: dict = None) -> InstanceList: """Generate bboxes from bbox head predictions. Args: @@ -463,12 +471,22 @@ def get_results(self, cls_score (torch.Tensor): Scores of bounding boxes. bbox_pred (torch.Tensor): Bounding boxes predictions class_labels (torch.Tensor): Label of classes - img_metas (list[dict]): Point cloud and image's meta info. + input_metas (list[dict]): Point cloud and image's meta info. cfg (:obj:`ConfigDict`, optional): Testing config. Defaults to None. Returns: - list[tuple]: Decoded bbox, scores and labels after nms. + list[:obj:`InstanceData`]: Detection results of each sample + after the post process. + Each item usually contains following keys. + + - scores_3d (Tensor): Classification scores, has a shape + (num_instances, ) + - labels_3d (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes, + contains a tensor with shape (num_instances, C), where + C >= 7. 
""" roi_batch_id = rois[..., 0] roi_boxes = rois[..., 1:] # boxes without batch id diff --git a/mmdet3d/models/roi_heads/part_aggregation_roi_head.py b/mmdet3d/models/roi_heads/part_aggregation_roi_head.py index aed8f2e4c5..92544fc149 100644 --- a/mmdet3d/models/roi_heads/part_aggregation_roi_head.py +++ b/mmdet3d/models/roi_heads/part_aggregation_roi_head.py @@ -89,10 +89,9 @@ def _bbox_forward_train(self, feats_dict: Dict, voxels_dict: Dict, bbox_results.update(loss_bbox=loss_bbox) return bbox_results - def _assign_and_sample( - self, proposal_list: InstanceList, - batch_gt_instances_3d: InstanceList, - batch_gt_instances_ignore) -> List[SamplingResult]: + def _assign_and_sample(self, proposal_list: InstanceList, + batch_gt_instances_3d: InstanceList, + batch_gt_instances_ignore) -> List[SamplingResult]: """Assign and sample proposals for training. Args: @@ -116,8 +115,7 @@ def _assign_and_sample( cur_gt_instances_ignore = batch_gt_instances_ignore[batch_idx] cur_gt_instances_3d.bboxes_3d = cur_gt_instances_3d.\ bboxes_3d.tensor - cur_gt_bboxes = cur_gt_instances_3d.bboxes_3d.to( - cur_boxes.device) + cur_gt_bboxes = cur_gt_instances_3d.bboxes_3d.to(cur_boxes.device) cur_gt_labels = cur_gt_instances_3d.labels_3d batch_num_gts = 0 @@ -161,7 +159,8 @@ def _assign_and_sample( batch_gt_labels) else: # for single class assign_result = self.bbox_assigner.assign( - cur_proposal_list, cur_gt_instances_3d,cur_gt_instances_ignore) + cur_proposal_list, cur_gt_instances_3d, + cur_gt_instances_ignore) # sample boxes sampling_result = self.bbox_sampler.sample(assign_result, cur_boxes.tensor, diff --git a/mmdet3d/models/roi_heads/point_rcnn_roi_head.py b/mmdet3d/models/roi_heads/point_rcnn_roi_head.py index ace2882121..38864c3d8d 100644 --- a/mmdet3d/models/roi_heads/point_rcnn_roi_head.py +++ b/mmdet3d/models/roi_heads/point_rcnn_roi_head.py @@ -1,9 +1,12 @@ # Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Dict + import torch from torch.nn import functional as F from mmdet3d.registry import MODELS, TASK_UTILS -from mmdet3d.structures import bbox3d2result, bbox3d2roi +from mmdet3d.structures import bbox3d2roi +from mmdet3d.utils.typing import InstanceList, SampleList from mmdet.models.task_modules import AssignResult from .base_3droi_head import Base3DRoIHead @@ -56,26 +59,21 @@ def init_assigner_sampler(self): ] self.bbox_sampler = TASK_UTILS.build(self.train_cfg.sampler) - def loss(self, feats_dict, rpn_results_list, - batch_data_samples, **kwargs) -> dict: - """Training forward function of PointRCNNRoIHead. + def loss(self, feats_dict: Dict, rpn_results_list: InstanceList, + batch_data_samples: SampleList, **kwargs) -> dict: + """Perform forward propagation and loss calculation of the detection + roi on the features of the upstream network. Args: feats_dict (dict): Contains features from the first stage. - imput_metas (list[dict]): Meta info of each input. - proposal_list (list[dict]): Proposal information from rpn. - The dictionary should contain the following keys: - - - boxes_3d (:obj:`BaseInstance3DBoxes`): Proposal bboxes - - labels_3d (torch.Tensor): Labels of proposals - gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): - GT bboxes of each sample. The bboxes are encapsulated - by 3D box bboxes_3d. - gt_labels_3d (list[LongTensor]): GT labels of each sample. + rpn_results_list (List[:obj:`InstancesData`]): Detection results + of rpn head. + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + samples. It usually includes information such as + `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`. Returns: - dict: Losses from RoI RCNN head. 
- - loss_bbox (torch.Tensor): Loss of bboxes + dict[str, Tensor]: A dictionary of loss components """ features = feats_dict['features'] points = feats_dict['points'] @@ -90,7 +88,9 @@ def loss(self, feats_dict, rpn_results_list, batch_gt_instances_ignore.append(data_sample.ignored_instances) else: batch_gt_instances_ignore.append(None) - sample_results = self._assign_and_sample(rpn_results_list, batch_gt_instances_3d,batch_gt_instances_ignore) + sample_results = self._assign_and_sample(rpn_results_list, + batch_gt_instances_3d, + batch_gt_instances_ignore) # concat the depth, semantic features and backbone features features = features.transpose(1, 2).contiguous() @@ -109,25 +109,39 @@ def loss(self, feats_dict, rpn_results_list, return losses def predict(self, - feats_dict, - rpn_results_list, - batch_data_samples, + feats_dict: Dict, + rpn_results_list: InstanceList, + batch_data_samples: SampleList, rescale: bool = False, - **kwargs): - """Simple testing forward function of PointRCNNRoIHead. - - Note: - This function assumes that the batch size is 1 + **kwargs) -> InstanceList: + """Perform forward propagation of the roi head and predict detection + results on the features of the upstream network. Args: feats_dict (dict): Contains features from the first stage. - img_metas (list[dict]): Meta info of each image. - proposal_list (list[dict]): Proposal information from rpn. + rpn_results_list (List[:obj:`InstancesData`]): Detection results + of rpn head. + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + samples. It usually includes information such as + `gt_instance_3d`, `gt_panoptic_seg_3d` and `gt_sem_seg_3d`. + rescale (bool): If True, return boxes in original image space. + Defaults to False. Returns: - dict: Bbox results of one frame. + list[:obj:`InstanceData`]: Detection results of each sample + after the post process. + Each item usually contains following keys. 
+ + - scores_3d (Tensor): Classification scores, has a shape + (num_instances, ) + - labels_3d (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bboxes_3d (BaseInstance3DBoxes): Prediction of bboxes, + contains a tensor with shape (num_instances, C), where + C >= 7. """ - rois = bbox3d2roi([res['bboxes_3d'].tensor for res in rpn_results_list]) + rois = bbox3d2roi( + [res['bboxes_3d'].tensor for res in rpn_results_list]) labels_3d = [res['labels_3d'] for res in rpn_results_list] batch_input_metas = [ data_samples.metainfo for data_samples in batch_data_samples @@ -201,13 +215,14 @@ def _bbox_forward(self, features, points, batch_size, rois): features of roi_extractor. """ pooled_point_feats = self.bbox_roi_extractor(features, points, - batch_size, rois) + batch_size, rois) cls_score, bbox_pred = self.bbox_head(pooled_point_feats) bbox_results = dict(cls_score=cls_score, bbox_pred=bbox_pred) return bbox_results - def _assign_and_sample(self, rpn_results_list, batch_gt_instances_3d,batch_gt_instances_ignore): + def _assign_and_sample(self, rpn_results_list, batch_gt_instances_3d, + batch_gt_instances_ignore): """Assign and sample proposals for training. 
Args: @@ -273,7 +288,8 @@ def _assign_and_sample(self, rpn_results_list, batch_gt_instances_3d,batch_gt_in batch_gt_labels) else: # for single class assign_result = self.bbox_assigner.assign( - cur_proposal_list, cur_gt_instances_3d, cur_gt_instances_ignore) + cur_proposal_list, cur_gt_instances_3d, + cur_gt_instances_ignore) # sample boxes sampling_result = self.bbox_sampler.sample(assign_result, diff --git a/mmdet3d/utils/typing.py b/mmdet3d/utils/typing.py index 9292b57c26..5489fe313a 100644 --- a/mmdet3d/utils/typing.py +++ b/mmdet3d/utils/typing.py @@ -5,6 +5,7 @@ from mmengine.config import ConfigDict from mmengine.structures import InstanceData +from mmdet3d.structures.det3d_data_sample import Det3DDataSample from mmdet.models.task_modules.samplers import SamplingResult # Type hint of config data @@ -21,3 +22,4 @@ SamplingResultList = List[SamplingResult] OptSamplingResultList = Optional[SamplingResultList] +SampleList = List[Det3DDataSample] diff --git a/tests/test_models/test_detectors/test_pointrcnn.py b/tests/test_models/test_detectors/test_pointrcnn.py index 2e7c0cfe0c..293a6eb89e 100644 --- a/tests/test_models/test_detectors/test_pointrcnn.py +++ b/tests/test_models/test_detectors/test_pointrcnn.py @@ -21,8 +21,7 @@ def test_pointrcnn(self): model = MODELS.build(pointrcnn_cfg) num_gt_instance = 2 packed_inputs = _create_detector_inputs( - num_points=10101, - num_gt_instance=num_gt_instance) + num_points=10101, num_gt_instance=num_gt_instance) if torch.cuda.is_available(): model = model.cuda() From 77bc0d128e3ea83d71c7822237b23f0da87b7fcd Mon Sep 17 00:00:00 2001 From: VVsssssk Date: Tue, 13 Sep 2022 11:14:56 +0800 Subject: [PATCH 08/16] fix --- .../point-rcnn_8xb2_kitti-3d-3class.py | 42 +++++++++++++++++++ .../bbox_heads/point_rcnn_bbox_head.py | 9 ++-- 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py b/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py index 51e31b5e8f..d7fbc1ccd0 
100644 --- a/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py +++ b/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py @@ -86,3 +86,45 @@ # or not by default. # - `base_batch_size` = (8 GPUs) x (2 samples per GPU). auto_scale_lr = dict(enable=False, base_batch_size=16) +param_scheduler = [ + # learning rate scheduler + # During the first 16 epochs, learning rate increases from 0 to lr * 10 + # during the next 24 epochs, learning rate decreases from lr * 10 to + # lr * 1e-4 + dict( + type='CosineAnnealingLR', + T_max=35, + eta_min=lr * 10, + begin=0, + end=35, + by_epoch=True, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=45, + eta_min=lr * 1e-4, + begin=35, + end=80, + by_epoch=True, + convert_to_iter_based=True), + # momentum scheduler + # During the first 16 epochs, momentum increases from 0 to 0.85 / 0.95 + # during the next 24 epochs, momentum increases from 0.85 / 0.95 to 1 + dict( + type='CosineAnnealingMomentum', + T_max=35, + eta_min=0.85 / 0.95, + begin=0, + end=35, + by_epoch=True, + convert_to_iter_based=True), + dict( + type='CosineAnnealingMomentum', + T_max=45, + eta_min=1, + begin=35, + end=80, + by_epoch=True, + convert_to_iter_based=True) +] +load_from = 'work_dirs/pointrcnn/epoch_34.pth' \ No newline at end of file diff --git a/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py b/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py index b52c58f1b6..8fd65b49e6 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py +++ b/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py @@ -348,10 +348,11 @@ def get_corner_loss_lidar(self, pred_bbox3d, gt_bbox3d, delta=1.0): torch.norm(pred_box_corners - gt_box_corners_flip, dim=2)) # huber loss abs_error = corner_dist.abs() - quadratic = abs_error.clamp(max=delta) - linear = (abs_error - quadratic) - corner_loss = 0.5 * quadratic**2 + delta * linear - return corner_loss.mean(dim=1) + # quadratic = abs_error.clamp(max=delta) + # linear 
= (abs_error - quadratic) + # corner_loss = 0.5 * quadratic**2 + delta * linear + loss = torch.where(abs_error < delta, 0.5 * abs_error ** 2 / delta, abs_error - 0.5 * delta) + return loss.mean() def get_targets(self, sampling_results, rcnn_train_cfg, concat=True): """Generate targets. From 12faea143181816d3304a5019314ce560bb60593 Mon Sep 17 00:00:00 2001 From: VVsssssk Date: Tue, 13 Sep 2022 14:38:53 +0800 Subject: [PATCH 09/16] fix --- mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py | 4 ++-- .../models/roi_heads/bbox_heads/point_rcnn_bbox_head.py | 7 +++---- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py b/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py index 9d2d17bef5..c56384b8f0 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py +++ b/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py @@ -329,9 +329,9 @@ def loss(self, cls_score, bbox_pred, rois, labels, bbox_targets, pos_inds = (reg_mask > 0) if pos_inds.any() == 0: # fake a part loss - losses['loss_bbox'] = loss_cls.new_tensor(0) + losses['loss_bbox'] = loss_cls.new_tensor(0) * loss_cls.sum() if self.with_corner_loss: - losses['loss_corner'] = loss_cls.new_tensor(0) + losses['loss_corner'] = loss_cls.new_tensor(0) * loss_cls.sum() else: pos_bbox_pred = bbox_pred.view(rcnn_batch_size, -1)[pos_inds] bbox_weights_flat = bbox_weights[pos_inds].view(-1, 1).repeat( diff --git a/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py b/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py index 8fd65b49e6..b842d66d36 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py +++ b/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py @@ -310,12 +310,11 @@ def loss(self, cls_score: Tensor, bbox_pred: Tensor, rois: Tensor, # calculate corner loss loss_corner = self.get_corner_loss_lidar(pred_boxes3d, - pos_gt_bboxes) + pos_gt_bboxes).mean() losses['loss_corner'] = loss_corner 
else: - losses['loss_corner'] = loss_cls.new_tensor(0) - + losses['loss_corner'] = loss_cls.new_tensor(0) * loss_cls.sum() return losses def get_corner_loss_lidar(self, pred_bbox3d, gt_bbox3d, delta=1.0): @@ -352,7 +351,7 @@ def get_corner_loss_lidar(self, pred_bbox3d, gt_bbox3d, delta=1.0): # linear = (abs_error - quadratic) # corner_loss = 0.5 * quadratic**2 + delta * linear loss = torch.where(abs_error < delta, 0.5 * abs_error ** 2 / delta, abs_error - 0.5 * delta) - return loss.mean() + return loss.mean(dim=1) def get_targets(self, sampling_results, rcnn_train_cfg, concat=True): """Generate targets. From 6363b70419d3c1bf38e885c61419f49c951ad3cf Mon Sep 17 00:00:00 2001 From: VVsssssk Date: Fri, 16 Sep 2022 11:27:14 +0800 Subject: [PATCH 10/16] fix --- .../point-rcnn_8xb2_kitti-3d-3class.py | 9 +- docs/en/advanced_guides/customize_models.md | 6 +- mmdet3d/models/dense_heads/point_rpn_head.py | 40 ++++---- mmdet3d/models/detectors/point_rcnn.py | 6 +- .../roi_heads/bbox_heads/parta2_bbox_head.py | 76 ++++++++------- .../bbox_heads/point_rcnn_bbox_head.py | 39 +++++--- .../mask_heads/pointwise_semantic_head.py | 48 +++++---- .../roi_heads/mask_heads/primitive_head.py | 97 +++++++++++-------- .../roi_heads/part_aggregation_roi_head.py | 24 ++--- .../models/roi_heads/point_rcnn_roi_head.py | 59 ++++++----- .../single_roiaware_extractor.py | 9 +- .../single_roipoint_extractor.py | 10 +- 12 files changed, 243 insertions(+), 180 deletions(-) diff --git a/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py b/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py index d7fbc1ccd0..f0e7c96e1f 100644 --- a/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py +++ b/configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py @@ -88,8 +88,8 @@ auto_scale_lr = dict(enable=False, base_batch_size=16) param_scheduler = [ # learning rate scheduler - # During the first 16 epochs, learning rate increases from 0 to lr * 10 - # during the next 24 epochs, learning rate decreases from lr * 
10 to + # During the first 35 epochs, learning rate increases from 0 to lr * 10 + # during the next 45 epochs, learning rate decreases from lr * 10 to # lr * 1e-4 dict( type='CosineAnnealingLR', @@ -108,8 +108,8 @@ by_epoch=True, convert_to_iter_based=True), # momentum scheduler - # During the first 16 epochs, momentum increases from 0 to 0.85 / 0.95 - # during the next 24 epochs, momentum increases from 0.85 / 0.95 to 1 + # During the first 35 epochs, momentum increases from 0 to 0.85 / 0.95 + # during the next 45 epochs, momentum increases from 0.85 / 0.95 to 1 dict( type='CosineAnnealingMomentum', T_max=35, @@ -127,4 +127,3 @@ by_epoch=True, convert_to_iter_based=True) ] -load_from = 'work_dirs/pointrcnn/epoch_34.pth' \ No newline at end of file diff --git a/docs/en/advanced_guides/customize_models.md b/docs/en/advanced_guides/customize_models.md index f5876e3747..18fc8bb743 100644 --- a/docs/en/advanced_guides/customize_models.md +++ b/docs/en/advanced_guides/customize_models.md @@ -365,7 +365,7 @@ class PartAggregationROIHead(Base3DRoIHead): Args: feats_dict (dict): Contains features from the first stage. - rpn_results_list (List[:obj:`InstancesData`]): Detection results + rpn_results_list (List[:obj:`InstanceData`]): Detection results of rpn head. batch_data_samples (List[:obj:`Det3DDataSample`]): The Data samples. It usually includes information such as @@ -412,7 +412,7 @@ class PartAggregationROIHead(Base3DRoIHead): voxel_dict (dict): Contains information of voxels. batch_input_metas (list[dict], Optional): Batch image meta info. Defaults to None. - rpn_results_list (List[:obj:`InstancesData`]): Detection results + rpn_results_list (List[:obj:`InstanceData`]): Detection results of rpn head. test_cfg (Config): Test config. @@ -438,7 +438,7 @@ class PartAggregationROIHead(Base3DRoIHead): Args: feats_dict (dict): Contains features from the first stage. 
- rpn_results_list (List[:obj:`InstancesData`]): Detection results + rpn_results_list (List[:obj:`InstanceData`]): Detection results of rpn head. batch_data_samples (List[:obj:`Det3DDataSample`]): The Data samples. It usually includes information such as diff --git a/mmdet3d/models/dense_heads/point_rpn_head.py b/mmdet3d/models/dense_heads/point_rpn_head.py index be7d7059b1..97a7a030c7 100644 --- a/mmdet3d/models/dense_heads/point_rpn_head.py +++ b/mmdet3d/models/dense_heads/point_rpn_head.py @@ -10,7 +10,8 @@ from mmdet3d.models.layers import nms_bev, nms_normal_bev from mmdet3d.registry import MODELS, TASK_UTILS from mmdet3d.structures import xywhr2xyxyr -from mmdet3d.structures.bbox_3d import (DepthInstance3DBoxes, +from mmdet3d.structures.bbox_3d import (BaseInstance3DBoxes, + DepthInstance3DBoxes, LiDARInstance3DBoxes) from mmdet3d.structures.det3d_data_sample import SampleList from mmdet3d.utils.typing import InstanceList @@ -72,7 +73,8 @@ def __init__(self, input_channels=pred_layer_cfg.in_channels, output_channels=self._get_reg_out_channels()) - def _make_fc_layers(self, fc_cfg, input_channels, output_channels): + def _make_fc_layers(self, fc_cfg: dict, input_channels: int, + output_channels: int) -> nn.Sequential: """Make fully connect layers. Args: @@ -107,7 +109,7 @@ def _get_reg_out_channels(self): # torch.cos(yaw) (1), torch.sin(yaw) (1) return self.bbox_coder.code_size - def forward(self, feat_dict: dict) -> Tuple[list]: + def forward(self, feat_dict: dict) -> Tuple[List[Tensor]]: """Forward pass. Args: @@ -117,7 +119,7 @@ def forward(self, feat_dict: dict) -> Tuple[list]: tuple[list[torch.Tensor]]: Predicted boxes and classification scores. 
""" - point_features = feat_dict['features'] + point_features = feat_dict['fp_features'] point_features = point_features.permute(0, 2, 1).contiguous() batch_size = point_features.shape[0] feat_cls = point_features.view(-1, point_features.shape[-1]) @@ -177,14 +179,15 @@ def loss_by_feat(self, return losses - def get_targets(self, points, batch_gt_instances_3d): + def get_targets(self, points: List[Tensor], + batch_gt_instances_3d: InstanceList) -> Tuple[Tensor]: """Generate targets of PointRCNN RPN head. Args: points (list[torch.Tensor]): Points of each batch. - gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth - bboxes of each batch. - gt_labels_3d (list[torch.Tensor]): Labels of each batch. + batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of + gt_instances. It usually includes ``bboxes_3d`` and + ``labels_3d`` attributes. Returns: tuple[torch.Tensor]: Targets of PointRCNN RPN head. @@ -216,7 +219,9 @@ def get_targets(self, points, batch_gt_instances_3d): return (bbox_targets, mask_targets, positive_mask, negative_mask, box_loss_weights, point_targets) - def get_targets_single(self, points, gt_bboxes_3d, gt_labels_3d): + def get_targets_single(self, points: Tensor, + gt_bboxes_3d: BaseInstance3DBoxes, + gt_labels_3d: Tensor) -> Tuple[Tensor]: """Generate targets of PointRCNN RPN head for single batch. Args: @@ -386,7 +391,8 @@ def class_agnostic_nms(self, obj_scores: Tensor, sem_scores: Tensor, cls_preds = obj_scores.new_zeros([0, sem_scores.shape[-1]]) return bbox_selected, score_selected, labels, cls_preds - def _assign_targets_by_points_inside(self, bboxes_3d, points): + def _assign_targets_by_points_inside(self, bboxes_3d: BaseInstance3DBoxes, + points: Tensor) -> Tuple[Tensor]: """Compute assignment by checking whether point is inside bbox. 
Args: @@ -442,12 +448,12 @@ def predict(self, feats_dict: Dict, batch_input_metas = [ data_samples.metainfo for data_samples in batch_data_samples ] - stack_points = feats_dict.pop('stack_points') + raw_points = feats_dict.pop('raw_points') bbox_preds, cls_preds = self(feats_dict) proposal_cfg = self.test_cfg proposal_list = self.predict_by_feat( - stack_points, + raw_points, bbox_preds, cls_preds, cfg=proposal_cfg, @@ -485,16 +491,16 @@ def loss_and_predict(self, batch_gt_instances_3d.append(data_sample.gt_instances_3d) batch_gt_instances_ignore.append( data_sample.get('ignored_instances', None)) - stack_points = feats_dict.pop('stack_points') + raw_points = feats_dict.pop('raw_points') bbox_preds, cls_preds = self(feats_dict) - loss_inputs = (bbox_preds, cls_preds, stack_points) + ( - batch_gt_instances_3d, batch_input_metas, - batch_gt_instances_ignore) + loss_inputs = (bbox_preds, cls_preds, + raw_points) + (batch_gt_instances_3d, batch_input_metas, + batch_gt_instances_ignore) losses = self.loss_by_feat(*loss_inputs) predictions = self.predict_by_feat( - stack_points, + raw_points, bbox_preds, cls_preds, batch_input_metas=batch_input_metas, diff --git a/mmdet3d/models/detectors/point_rcnn.py b/mmdet3d/models/detectors/point_rcnn.py index dd014a593b..359ba0a067 100644 --- a/mmdet3d/models/detectors/point_rcnn.py +++ b/mmdet3d/models/detectors/point_rcnn.py @@ -62,6 +62,6 @@ def extract_feat(self, batch_inputs_dict: Dict) -> Dict: if self.with_neck: x = self.neck(x) return dict( - features=x['fp_features'].clone(), - points=x['fp_xyz'].clone(), - stack_points=points) + fp_features=x['fp_features'].clone(), + fp_points=x['fp_xyz'].clone(), + raw_points=points) diff --git a/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py b/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py index c56384b8f0..1f1d328fbd 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py +++ b/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py @@ -1,5 +1,5 @@ # 
Copyright (c) OpenMMLab. All rights reserved. -from typing import List +from typing import Dict, List, Tuple import numpy as np import torch @@ -26,6 +26,7 @@ from mmdet3d.registry import MODELS, TASK_UTILS from mmdet3d.structures.bbox_3d import (LiDARInstance3DBoxes, rotation_3d_in_axis, xywhr2xyxyr) +from mmdet3d.utils.typing import SamplingResultList @MODELS.register_module() @@ -56,34 +57,34 @@ class PartA2BboxHead(BaseModule): conv_cfg (dict): Config dict of convolutional layers norm_cfg (dict): Config dict of normalization layers loss_bbox (dict): Config dict of box regression loss. - loss_cls (dict): Config dict of classifacation loss. + loss_cls (dict, optional): Config dict of classifacation loss. """ def __init__(self, - num_classes, - seg_in_channels, - part_in_channels, - seg_conv_channels=None, - part_conv_channels=None, - merge_conv_channels=None, - down_conv_channels=None, - shared_fc_channels=None, - cls_channels=None, - reg_channels=None, - dropout_ratio=0.1, - roi_feat_size=14, - with_corner_loss=True, - bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'), - conv_cfg=dict(type='Conv1d'), - norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01), - loss_bbox=dict( + num_classes: int, + seg_in_channels: int, + part_in_channels: int, + seg_conv_channels: List[int] = None, + part_conv_channels: List[int] = None, + merge_conv_channels: List[int] = None, + down_conv_channels: List[int] = None, + shared_fc_channels: List[int] = None, + cls_channels: List[int] = None, + reg_channels: List[int] = None, + dropout_ratio: float = 0.1, + roi_feat_size: int = 14, + with_corner_loss: bool = True, + bbox_coder: dict = dict(type='DeltaXYZWLHRBBoxCoder'), + conv_cfg: dict = dict(type='Conv1d'), + norm_cfg: dict = dict(type='BN1d', eps=1e-3, momentum=0.01), + loss_bbox: dict = dict( type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=2.0), - loss_cls=dict( + loss_cls: dict = dict( type='CrossEntropyLoss', use_sigmoid=True, reduction='none', loss_weight=1.0), - init_cfg=None): + 
init_cfg: dict = None) -> None: super(PartA2BboxHead, self).__init__(init_cfg=init_cfg) self.num_classes = num_classes self.with_corner_loss = with_corner_loss @@ -244,7 +245,7 @@ def init_weights(self): super().init_weights() normal_init(self.conv_reg[-1].conv, mean=0, std=0.001) - def forward(self, seg_feats, part_feats): + def forward(self, seg_feats: Tensor, part_feats: Tensor) -> Tuple[Tensor]: """Forward pass. Args: @@ -294,8 +295,10 @@ def forward(self, seg_feats, part_feats): return cls_score, bbox_pred - def loss(self, cls_score, bbox_pred, rois, labels, bbox_targets, - pos_gt_bboxes, reg_mask, label_weights, bbox_weights): + def loss(self, cls_score: Tensor, bbox_pred: Tensor, rois: Tensor, + labels: Tensor, bbox_targets: Tensor, pos_gt_bboxes: Tensor, + reg_mask: Tensor, label_weights: Tensor, + bbox_weights: Tensor) -> Dict: """Computing losses. Args: @@ -367,7 +370,10 @@ def loss(self, cls_score, bbox_pred, rois, labels, bbox_targets, return losses - def get_targets(self, sampling_results, rcnn_train_cfg, concat=True): + def get_targets(self, + sampling_results: SamplingResultList, + rcnn_train_cfg: dict, + concat: bool = True) -> Tuple[Tensor]: """Generate targets. Args: @@ -407,7 +413,8 @@ def get_targets(self, sampling_results, rcnn_train_cfg, concat=True): return (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, bbox_weights) - def _get_target_single(self, pos_bboxes, pos_gt_bboxes, ious, cfg): + def _get_target_single(self, pos_bboxes: Tensor, pos_gt_bboxes: Tensor, + ious: Tensor, cfg: dict) -> Tuple[Tensor]: """Generate training targets for a single sample. 
Args: @@ -472,7 +479,10 @@ def _get_target_single(self, pos_bboxes, pos_gt_bboxes, ious, cfg): return (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, bbox_weights) - def get_corner_loss_lidar(self, pred_bbox3d, gt_bbox3d, delta=1.0): + def get_corner_loss_lidar(self, + pred_bbox3d: Tensor, + gt_bbox3d: Tensor, + delta: float = 1.0) -> Tensor: """Calculate corner loss of given boxes. Args: @@ -580,12 +590,12 @@ def get_results(self, return result_list def multi_class_nms(self, - box_probs, - box_preds, - score_thr, - nms_thr, - input_meta, - use_rotate_nms=True): + box_probs: Tensor, + box_preds: Tensor, + score_thr: float, + nms_thr: float, + input_meta: dict, + use_rotate_nms: bool = True) -> Tensor: """Multi-class NMS for box head. Note: diff --git a/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py b/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py index b842d66d36..43a288dd1c 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py +++ b/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py @@ -3,19 +3,19 @@ import numpy as np import torch +import torch.nn as nn from mmcv.cnn import ConvModule from mmcv.cnn.bricks import build_conv_layer from mmengine.model import BaseModule, normal_init from mmengine.structures import InstanceData from torch import Tensor -from torch import nn as nn from mmdet3d.models.layers import nms_bev, nms_normal_bev from mmdet3d.models.layers.pointnet_modules import build_sa_module from mmdet3d.registry import MODELS, TASK_UTILS from mmdet3d.structures.bbox_3d import (LiDARInstance3DBoxes, rotation_3d_in_axis, xywhr2xyxyr) -from mmdet3d.utils.typing import InstanceList +from mmdet3d.utils.typing import InstanceList, SamplingResultList from mmdet.models.utils import multi_apply @@ -175,7 +175,8 @@ def __init__(self, if init_cfg is None: self.init_cfg = dict(type='Xavier', layer=['Conv2d', 'Conv1d']) - def _add_conv_branch(self, in_channels, conv_channels): + def 
_add_conv_branch(self, in_channels: int, + conv_channels: tuple) -> nn.Sequential: """Add shared or separable branch. Args: @@ -317,7 +318,10 @@ def loss(self, cls_score: Tensor, bbox_pred: Tensor, rois: Tensor, losses['loss_corner'] = loss_cls.new_tensor(0) * loss_cls.sum() return losses - def get_corner_loss_lidar(self, pred_bbox3d, gt_bbox3d, delta=1.0): + def get_corner_loss_lidar(self, + pred_bbox3d: Tensor, + gt_bbox3d: Tensor, + delta: float = 1.0) -> Tensor: """Calculate corner loss of given boxes. Args: @@ -350,17 +354,21 @@ def get_corner_loss_lidar(self, pred_bbox3d, gt_bbox3d, delta=1.0): # quadratic = abs_error.clamp(max=delta) # linear = (abs_error - quadratic) # corner_loss = 0.5 * quadratic**2 + delta * linear - loss = torch.where(abs_error < delta, 0.5 * abs_error ** 2 / delta, abs_error - 0.5 * delta) + loss = torch.where(abs_error < delta, 0.5 * abs_error**2 / delta, + abs_error - 0.5 * delta) return loss.mean(dim=1) - def get_targets(self, sampling_results, rcnn_train_cfg, concat=True): + def get_targets(self, + sampling_results: SamplingResultList, + rcnn_train_cfg: dict, + concat: bool = True) -> Tuple[Tensor]: """Generate targets. Args: sampling_results (list[:obj:`SamplingResult`]): Sampled results from rois. rcnn_train_cfg (:obj:`ConfigDict`): Training config of rcnn. - concat (bool, optional): Whether to concatenate targets between + concat (bool): Whether to concatenate targets between batches. Defaults to True. Returns: @@ -393,7 +401,8 @@ def get_targets(self, sampling_results, rcnn_train_cfg, concat=True): return (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, bbox_weights) - def _get_target_single(self, pos_bboxes, pos_gt_bboxes, ious, cfg): + def _get_target_single(self, pos_bboxes: Tensor, pos_gt_bboxes: Tensor, + ious: Tensor, cfg: dict) -> Tuple[Tensor]: """Generate training targets for a single sample. 
Args: @@ -527,12 +536,12 @@ def get_results(self, return result_list def multi_class_nms(self, - box_probs, - box_preds, - score_thr, - nms_thr, - input_meta, - use_rotate_nms=True): + box_probs: Tensor, + box_preds: Tensor, + score_thr: float, + nms_thr: float, + input_meta: dict, + use_rotate_nms: bool = True) -> Tensor: """Multi-class NMS for box head. Note: @@ -547,7 +556,7 @@ def multi_class_nms(self, score_thr (float): Threshold of scores. nms_thr (float): Threshold for NMS. input_meta (dict): Meta information of the current sample. - use_rotate_nms (bool, optional): Whether to use rotated nms. + use_rotate_nms (bool): Whether to use rotated nms. Defaults to True. Returns: diff --git a/mmdet3d/models/roi_heads/mask_heads/pointwise_semantic_head.py b/mmdet3d/models/roi_heads/mask_heads/pointwise_semantic_head.py index 73622202df..70175a879b 100644 --- a/mmdet3d/models/roi_heads/mask_heads/pointwise_semantic_head.py +++ b/mmdet3d/models/roi_heads/mask_heads/pointwise_semantic_head.py @@ -1,11 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, Tuple + import torch from mmengine.model import BaseModule +from torch import Tensor from torch import nn as nn from torch.nn import functional as F from mmdet3d.registry import MODELS -from mmdet3d.structures.bbox_3d import rotation_3d_in_axis +from mmdet3d.structures.bbox_3d import BaseInstance3DBoxes, rotation_3d_in_axis from mmdet3d.utils import InstanceList from mmdet.models.utils import multi_apply @@ -25,23 +28,23 @@ class PointwiseSemanticHead(BaseModule): loss_part (dict): Config of part prediction loss. 
""" - def __init__(self, - in_channels, - num_classes=3, - extra_width=0.2, - seg_score_thr=0.3, - init_cfg=None, - loss_seg=dict( - type='FocalLoss', - use_sigmoid=True, - reduction='sum', - gamma=2.0, - alpha=0.25, - loss_weight=1.0), - loss_part=dict( - type='CrossEntropyLoss', - use_sigmoid=True, - loss_weight=1.0)): + def __init__( + self, + in_channels: int, + num_classes: int = 3, + extra_width: float = 0.2, + seg_score_thr: float = 0.3, + init_cfg: dict = None, + loss_seg: dict = dict( + type='FocalLoss', + use_sigmoid=True, + reduction='sum', + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + loss_part: dict = dict( + type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0) + ) -> None: super(PointwiseSemanticHead, self).__init__(init_cfg=init_cfg) self.extra_width = extra_width self.num_classes = num_classes @@ -52,7 +55,7 @@ def __init__(self, self.loss_seg = MODELS.build(loss_seg) self.loss_part = MODELS.build(loss_part) - def forward(self, x): + def forward(self, x: Tensor) -> Dict[str, Tensor]: """Forward pass. Args: @@ -78,7 +81,9 @@ def forward(self, x): return dict( seg_preds=seg_preds, part_preds=part_preds, part_feats=part_feats) - def get_targets_single(self, voxel_centers, gt_bboxes_3d, gt_labels_3d): + def get_targets_single(self, voxel_centers: Tensor, + gt_bboxes_3d: BaseInstance3DBoxes, + gt_labels_3d: Tensor) -> Tuple[Tensor]: """generate segmentation and part prediction targets for a single sample. @@ -161,7 +166,8 @@ def get_targets(self, voxel_dict: dict, part_targets = torch.cat(part_targets, dim=0) return dict(seg_targets=seg_targets, part_targets=part_targets) - def loss(self, semantic_results, semantic_targets): + def loss(self, semantic_results: dict, + semantic_targets: dict) -> Dict[str, Tensor]: """Calculate point-wise segmentation and part prediction losses. 
Args: diff --git a/mmdet3d/models/roi_heads/mask_heads/primitive_head.py b/mmdet3d/models/roi_heads/mask_heads/primitive_head.py index 447e4404e0..317c38dbe7 100644 --- a/mmdet3d/models/roi_heads/mask_heads/primitive_head.py +++ b/mmdet3d/models/roi_heads/mask_heads/primitive_head.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Tuple import torch from mmcv.cnn import ConvModule @@ -12,6 +12,7 @@ from mmdet3d.models.layers import VoteModule, build_sa_module from mmdet3d.registry import MODELS from mmdet3d.structures import Det3DDataSample +from mmdet3d.structures.bbox_3d import BaseInstance3DBoxes from mmdet.models.utils import multi_apply @@ -126,7 +127,7 @@ def sample_mode(self): assert sample_mode in ['vote', 'seed', 'random'] return sample_mode - def forward(self, feats_dict): + def forward(self, feats_dict: dict) -> dict: """Forward pass. Args: @@ -392,12 +393,13 @@ def get_targets( return (point_mask, point_offset, gt_primitive_center, gt_primitive_semantic, gt_sem_cls_label, gt_votes_mask) - def get_targets_single(self, - points, - gt_bboxes_3d, - gt_labels_3d, - pts_semantic_mask=None, - pts_instance_mask=None): + def get_targets_single( + self, + points: torch.Tensor, + gt_bboxes_3d: BaseInstance3DBoxes, + gt_labels_3d: torch.Tensor, + pts_semantic_mask: torch.Tensor = None, + pts_instance_mask: torch.Tensor = None) -> Tuple[torch.Tensor]: """Generate targets of primitive head for single batch. Args: @@ -668,7 +670,8 @@ def get_targets_single(self, return (point_mask, point_sem, point_offset) - def primitive_decode_scores(self, predictions, aggregated_points): + def primitive_decode_scores(self, predictions: torch.Tensor, + aggregated_points: torch.Tensor) -> dict: """Decode predicted parts to primitive head. 
Args: @@ -696,7 +699,7 @@ def primitive_decode_scores(self, predictions, aggregated_points): return ret_dict - def check_horizon(self, points): + def check_horizon(self, points: torch.Tensor) -> bool: """Check whether is a horizontal plane. Args: @@ -709,7 +712,8 @@ def check_horizon(self, points): (points[1][-1] == points[2][-1]) and \ (points[2][-1] == points[3][-1]) - def check_dist(self, plane_equ, points): + def check_dist(self, plane_equ: torch.Tensor, + points: torch.Tensor) -> tuple: """Whether the mean of points to plane distance is lower than thresh. Args: @@ -722,7 +726,8 @@ def check_dist(self, plane_equ, points): return (points[:, 2] + plane_equ[-1]).sum() / 4.0 < self.train_cfg['lower_thresh'] - def point2line_dist(self, points, pts_a, pts_b): + def point2line_dist(self, points: torch.Tensor, pts_a: torch.Tensor, + pts_b: torch.Tensor) -> torch.Tensor: """Calculate the distance from point to line. Args: @@ -741,7 +746,11 @@ def point2line_dist(self, points, pts_a, pts_b): return dist - def match_point2line(self, points, corners, with_yaw, mode='bottom'): + def match_point2line(self, + points: torch.Tensor, + corners: torch.Tensor, + with_yaw: bool, + mode: str = 'bottom') -> tuple: """Match points to corresponding line. Args: @@ -782,7 +791,8 @@ def match_point2line(self, points, corners, with_yaw, mode='bottom'): selected_list = [sel1, sel2, sel3, sel4] return selected_list - def match_point2plane(self, plane, points): + def match_point2plane(self, plane: torch.Tensor, + points: torch.Tensor) -> tuple: """Match points to plane. 
Args: @@ -800,10 +810,14 @@ def match_point2plane(self, plane, points): min_dist) < self.train_cfg['dist_thresh'] return point2plane_dist, selected - def compute_primitive_loss(self, primitive_center, primitive_semantic, - semantic_scores, num_proposal, - gt_primitive_center, gt_primitive_semantic, - gt_sem_cls_label, gt_primitive_mask): + def compute_primitive_loss(self, primitive_center: torch.Tensor, + primitive_semantic: torch.Tensor, + semantic_scores: torch.Tensor, + num_proposal: torch.Tensor, + gt_primitive_center: torch.Tensor, + gt_primitive_semantic: torch.Tensor, + gt_sem_cls_label: torch.Tensor, + gt_primitive_mask: torch.Tensor) -> Tuple: """Compute loss of primitive module. Args: @@ -849,7 +863,8 @@ def compute_primitive_loss(self, primitive_center, primitive_semantic, return center_loss, size_loss, sem_cls_loss - def get_primitive_center(self, pred_flag, center): + def get_primitive_center(self, pred_flag: torch.Tensor, + center: torch.Tensor) -> Tuple: """Generate primitive center from predictions. Args: @@ -869,17 +884,17 @@ def get_primitive_center(self, pred_flag, center): return center, pred_indices def _assign_primitive_line_targets(self, - point_mask, - point_offset, - point_sem, - coords, - indices, - cls_label, - point2line_matching, - corners, - center_axises, - with_yaw, - mode='bottom'): + point_mask: torch.Tensor, + point_offset: torch.Tensor, + point_sem: torch.Tensor, + coords: torch.Tensor, + indices: torch.Tensor, + cls_label: int, + point2line_matching: torch.Tensor, + corners: torch.Tensor, + center_axises: torch.Tensor, + with_yaw: bool, + mode: str = 'bottom') -> Tuple: """Generate targets of line primitive. 
Args: @@ -934,15 +949,15 @@ def _assign_primitive_line_targets(self, return point_mask, point_offset, point_sem def _assign_primitive_surface_targets(self, - point_mask, - point_offset, - point_sem, - coords, - indices, - cls_label, - corners, - with_yaw, - mode='bottom'): + point_mask: torch.Tensor, + point_offset: torch.Tensor, + point_sem: torch.Tensor, + coords: torch.Tensor, + indices: torch.Tensor, + cls_label: int, + corners: torch.Tensor, + with_yaw: bool, + mode: str = 'bottom') -> Tuple: """Generate targets for primitive z and primitive xy. Args: @@ -1017,7 +1032,9 @@ def _assign_primitive_surface_targets(self, point_offset[indices] = center - coords return point_mask, point_offset, point_sem - def _get_plane_fomulation(self, vector1, vector2, point): + def _get_plane_fomulation(self, vector1: torch.Tensor, + vector2: torch.Tensor, + point: torch.Tensor) -> torch.Tensor: """Compute the equation of the plane. Args: diff --git a/mmdet3d/models/roi_heads/part_aggregation_roi_head.py b/mmdet3d/models/roi_heads/part_aggregation_roi_head.py index 92544fc149..d7605d3c5a 100644 --- a/mmdet3d/models/roi_heads/part_aggregation_roi_head.py +++ b/mmdet3d/models/roi_heads/part_aggregation_roi_head.py @@ -89,17 +89,19 @@ def _bbox_forward_train(self, feats_dict: Dict, voxels_dict: Dict, bbox_results.update(loss_bbox=loss_bbox) return bbox_results - def _assign_and_sample(self, proposal_list: InstanceList, - batch_gt_instances_3d: InstanceList, - batch_gt_instances_ignore) -> List[SamplingResult]: + def _assign_and_sample( + self, rpn_results_list: InstanceList, + batch_gt_instances_3d: InstanceList, + batch_gt_instances_ignore: InstanceList) -> List[SamplingResult]: """Assign and sample proposals for training. Args: - proposal_list (list[:obj:`InstancesData`]): Proposals produced by - rpn head. + rpn_results_list (List[:obj:`InstanceData`]): Detection results + of rpn head. batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of gt_instances. 
It usually includes ``bboxes_3d`` and ``labels_3d`` attributes. + batch_gt_instances_ignore (list): Ignore instances of gt bboxes. Returns: list[:obj:`SamplingResult`]: Sampled results of each training @@ -107,8 +109,8 @@ def _assign_and_sample(self, proposal_list: InstanceList, """ sampling_results = [] # bbox assign - for batch_idx in range(len(proposal_list)): - cur_proposal_list = proposal_list[batch_idx] + for batch_idx in range(len(rpn_results_list)): + cur_proposal_list = rpn_results_list[batch_idx] cur_boxes = cur_proposal_list['bboxes_3d'] cur_labels_3d = cur_proposal_list['labels_3d'] cur_gt_instances_3d = batch_gt_instances_3d[batch_idx] @@ -202,7 +204,7 @@ def predict(self, Args: feats_dict (dict): Contains features from the first stage. - rpn_results_list (List[:obj:`InstancesData`]): Detection results + rpn_results_list (List[:obj:`InstanceData`]): Detection results of rpn head. batch_data_samples (List[:obj:`Det3DDataSample`]): The Data samples. It usually includes information such as @@ -249,7 +251,7 @@ def predict_bbox(self, feats_dict: Dict, voxel_dict: Dict, voxel_dict (dict): Contains information of voxels. batch_input_metas (list[dict], Optional): Batch image meta info. Defaults to None. - rpn_results_list (List[:obj:`InstancesData`]): Detection results + rpn_results_list (List[:obj:`InstanceData`]): Detection results of rpn head. test_cfg (Config): Test config. @@ -318,7 +320,7 @@ def loss(self, feats_dict: Dict, rpn_results_list: InstanceList, Args: feats_dict (dict): Contains features from the first stage. - rpn_results_list (List[:obj:`InstancesData`]): Detection results + rpn_results_list (List[:obj:`InstanceData`]): Detection results of rpn head. batch_data_samples (List[:obj:`Det3DDataSample`]): The Data samples. It usually includes information such as @@ -361,7 +363,7 @@ def _forward(self, feats_dict: dict, Args: feats_dict (dict): Contains features from the first stage. 
- rpn_results_list (List[:obj:`InstancesData`]): Detection results + rpn_results_list (List[:obj:`InstanceData`]): Detection results of rpn head. Returns: diff --git a/mmdet3d/models/roi_heads/point_rcnn_roi_head.py b/mmdet3d/models/roi_heads/point_rcnn_roi_head.py index 38864c3d8d..e7942fc426 100644 --- a/mmdet3d/models/roi_heads/point_rcnn_roi_head.py +++ b/mmdet3d/models/roi_heads/point_rcnn_roi_head.py @@ -2,6 +2,7 @@ from typing import Dict import torch +from torch import Tensor from torch.nn import functional as F from mmdet3d.registry import MODELS, TASK_UTILS @@ -26,12 +27,12 @@ class PointRCNNRoIHead(Base3DRoIHead): """ def __init__(self, - bbox_head, - bbox_roi_extractor, - train_cfg, - test_cfg, - depth_normalizer=70.0, - init_cfg=None): + bbox_head: dict, + bbox_roi_extractor: dict, + train_cfg: dict, + test_cfg: dict, + depth_normalizer: dict = 70.0, + init_cfg: dict = None) -> None: super(PointRCNNRoIHead, self).__init__( bbox_head=bbox_head, bbox_roi_extractor=bbox_roi_extractor, @@ -66,7 +67,7 @@ def loss(self, feats_dict: Dict, rpn_results_list: InstanceList, Args: feats_dict (dict): Contains features from the first stage. - rpn_results_list (List[:obj:`InstancesData`]): Detection results + rpn_results_list (List[:obj:`InstanceData`]): Detection results of rpn head. batch_data_samples (List[:obj:`Det3DDataSample`]): The Data samples. 
It usually includes information such as @@ -75,8 +76,8 @@ def loss(self, feats_dict: Dict, rpn_results_list: InstanceList, Returns: dict[str, Tensor]: A dictionary of loss components """ - features = feats_dict['features'] - points = feats_dict['points'] + features = feats_dict['fp_features'] + fp_points = feats_dict['fp_points'] point_cls_preds = feats_dict['points_cls_preds'] sem_scores = point_cls_preds.sigmoid() point_scores = sem_scores.max(-1)[0] @@ -94,14 +95,14 @@ def loss(self, feats_dict: Dict, rpn_results_list: InstanceList, # concat the depth, semantic features and backbone features features = features.transpose(1, 2).contiguous() - point_depths = points.norm(dim=2) / self.depth_normalizer - 0.5 + point_depths = fp_points.norm(dim=2) / self.depth_normalizer - 0.5 features_list = [ point_scores.unsqueeze(2), point_depths.unsqueeze(2), features ] features = torch.cat(features_list, dim=2) - bbox_results = self._bbox_forward_train(features, points, + bbox_results = self._bbox_forward_train(features, fp_points, sample_results) losses = dict() losses.update(bbox_results['loss_bbox']) @@ -119,7 +120,7 @@ def predict(self, Args: feats_dict (dict): Contains features from the first stage. - rpn_results_list (List[:obj:`InstancesData`]): Detection results + rpn_results_list (List[:obj:`InstanceData`]): Detection results of rpn head. batch_data_samples (List[:obj:`Det3DDataSample`]): The Data samples. 
It usually includes information such as @@ -146,14 +147,14 @@ def predict(self, batch_input_metas = [ data_samples.metainfo for data_samples in batch_data_samples ] - features = feats_dict['features'] - points = feats_dict['points'] + fp_features = feats_dict['fp_features'] + fp_points = feats_dict['fp_points'] point_cls_preds = feats_dict['points_cls_preds'] sem_scores = point_cls_preds.sigmoid() point_scores = sem_scores.max(-1)[0] - features = features.transpose(1, 2).contiguous() - point_depths = points.norm(dim=2) / self.depth_normalizer - 0.5 + features = fp_features.transpose(1, 2).contiguous() + point_depths = fp_points.norm(dim=2) / self.depth_normalizer - 0.5 features_list = [ point_scores.unsqueeze(2), point_depths.unsqueeze(2), features @@ -161,7 +162,8 @@ def predict(self, features = torch.cat(features_list, dim=2) batch_size = features.shape[0] - bbox_results = self._bbox_forward(features, points, batch_size, rois) + bbox_results = self._bbox_forward(features, fp_points, batch_size, + rois) object_score = bbox_results['cls_score'].sigmoid() bbox_list = self.bbox_head.get_results( rois, @@ -173,7 +175,8 @@ def predict(self, return bbox_list - def _bbox_forward_train(self, features, points, sampling_results): + def _bbox_forward_train(self, features: Tensor, points: Tensor, + sampling_results: SampleList) -> dict: """Forward training function of roi_extractor and bbox_head. Args: @@ -199,7 +202,8 @@ def _bbox_forward_train(self, features, points, sampling_results): bbox_results.update(loss_bbox=loss_bbox) return bbox_results - def _bbox_forward(self, features, points, batch_size, rois): + def _bbox_forward(self, features: Tensor, points: Tensor, batch_size: int, + rois: Tensor) -> dict: """Forward function of roi_extractor and bbox_head used in both training and testing. 
@@ -221,15 +225,20 @@ def _bbox_forward(self, features, points, batch_size, rois): bbox_results = dict(cls_score=cls_score, bbox_pred=bbox_pred) return bbox_results - def _assign_and_sample(self, rpn_results_list, batch_gt_instances_3d, - batch_gt_instances_ignore): + def _assign_and_sample( + self, rpn_results_list: InstanceList, + batch_gt_instances_3d: InstanceList, + batch_gt_instances_ignore: InstanceList) -> SampleList: """Assign and sample proposals for training. Args: - proposal_list (list[dict]): Proposals produced by RPN. - gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth - boxes. - gt_labels_3d (list[torch.Tensor]): Ground truth labels + rpn_results_list (List[:obj:`InstanceData`]): Detection results + of rpn head. + batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of + gt_instances. It usually includes ``bboxes_3d`` and + ``labels_3d`` attributes. + batch_gt_instances_ignore (list[:obj:`InstanceData`]): Ignore + instances of gt bboxes. Returns: list[:obj:`SamplingResult`]: Sampled results of each training diff --git a/mmdet3d/models/roi_heads/roi_extractors/single_roiaware_extractor.py b/mmdet3d/models/roi_heads/roi_extractors/single_roiaware_extractor.py index 35acc16797..a9ad793907 100644 --- a/mmdet3d/models/roi_heads/roi_extractors/single_roiaware_extractor.py +++ b/mmdet3d/models/roi_heads/roi_extractors/single_roiaware_extractor.py @@ -1,7 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +import torch.nn as nn from mmcv import ops from mmengine.model import BaseModule +from torch import Tensor from mmdet3d.registry import MODELS @@ -16,11 +18,11 @@ class Single3DRoIAwareExtractor(BaseModule): roi_layer (dict): The config of roi layer. 
""" - def __init__(self, roi_layer=None, init_cfg=None): + def __init__(self, roi_layer: dict = None, init_cfg: dict = None) -> None: super(Single3DRoIAwareExtractor, self).__init__(init_cfg=init_cfg) self.roi_layer = self.build_roi_layers(roi_layer) - def build_roi_layers(self, layer_cfg): + def build_roi_layers(self, layer_cfg: dict) -> nn.Module: """Build roi layers using `layer_cfg`""" cfg = layer_cfg.copy() layer_type = cfg.pop('type') @@ -29,7 +31,8 @@ def build_roi_layers(self, layer_cfg): roi_layers = layer_cls(**cfg) return roi_layers - def forward(self, feats, coordinate, batch_inds, rois): + def forward(self, feats: Tensor, coordinate: Tensor, batch_inds: Tensor, + rois: Tensor) -> Tensor: """Extract point-wise roi features. Args: diff --git a/mmdet3d/models/roi_heads/roi_extractors/single_roipoint_extractor.py b/mmdet3d/models/roi_heads/roi_extractors/single_roipoint_extractor.py index df975f5dfb..043ee1be95 100644 --- a/mmdet3d/models/roi_heads/roi_extractors/single_roipoint_extractor.py +++ b/mmdet3d/models/roi_heads/roi_extractors/single_roipoint_extractor.py @@ -1,7 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import torch +import torch.nn as nn from mmcv import ops -from torch import nn as nn +from torch import Tensor from mmdet3d.registry import MODELS from mmdet3d.structures.bbox_3d import rotation_3d_in_axis @@ -17,11 +18,11 @@ class Single3DRoIPointExtractor(nn.Module): roi_layer (dict): The config of roi layer. 
""" - def __init__(self, roi_layer=None): + def __init__(self, roi_layer: dict = None) -> None: super(Single3DRoIPointExtractor, self).__init__() self.roi_layer = self.build_roi_layers(roi_layer) - def build_roi_layers(self, layer_cfg): + def build_roi_layers(self, layer_cfg: dict) -> nn.Module: """Build roi layers using `layer_cfg`""" cfg = layer_cfg.copy() layer_type = cfg.pop('type') @@ -30,7 +31,8 @@ def build_roi_layers(self, layer_cfg): roi_layers = layer_cls(**cfg) return roi_layers - def forward(self, feats, coordinate, batch_inds, rois): + def forward(self, feats: Tensor, coordinate: Tensor, batch_inds: Tensor, + rois: Tensor) -> Tensor: """Extract point-wise roi features. Args: From 4c08d134064cc42e5c7ba5850fc453440013fc5f Mon Sep 17 00:00:00 2001 From: VVsssssk Date: Mon, 19 Sep 2022 17:31:06 +0800 Subject: [PATCH 11/16] fix bug --- mmdet3d/models/dense_heads/point_rpn_head.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mmdet3d/models/dense_heads/point_rpn_head.py b/mmdet3d/models/dense_heads/point_rpn_head.py index 97a7a030c7..ac7067645e 100644 --- a/mmdet3d/models/dense_heads/point_rpn_head.py +++ b/mmdet3d/models/dense_heads/point_rpn_head.py @@ -298,9 +298,12 @@ def predict_by_feat(self, points: Tensor, bbox_preds: List[Tensor], for b in range(batch_size): bbox3d = self.bbox_coder.decode(bbox_preds[b], points[b, ..., :3], object_class[b]) + mask = ~bbox3d.sum(dim=1).isinf() bbox_selected, score_selected, labels, cls_preds_selected = \ - self.class_agnostic_nms(obj_scores[b], sem_scores[b], bbox3d, - points[b, ..., :3], + self.class_agnostic_nms(obj_scores[b][mask], + sem_scores[b][mask, :], + bbox3d[mask, :], + points[b, ..., :3][mask, :], batch_input_metas[b], cfg.nms_cfg) bbox_selected = batch_input_metas[b]['box_type_3d']( @@ -506,4 +509,6 @@ def loss_and_predict(self, batch_input_metas=batch_input_metas, cfg=proposal_cfg) feats_dict['points_cls_preds'] = cls_preds + if predictions[0].bboxes_3d.tensor.isinf().any(): 
+ print(predictions) return losses, predictions From 861c221f3005a1bef860fc7e2083aea6195bced2 Mon Sep 17 00:00:00 2001 From: VVsssssk Date: Mon, 26 Sep 2022 16:45:24 +0800 Subject: [PATCH 12/16] fix comments --- mmdet3d/datasets/transforms/loading.py | 5 +---- tools/model_converters/pointrcnn_convert.py | 25 --------------------- 2 files changed, 1 insertion(+), 29 deletions(-) delete mode 100644 tools/model_converters/pointrcnn_convert.py diff --git a/mmdet3d/datasets/transforms/loading.py b/mmdet3d/datasets/transforms/loading.py index 0a60bd111e..615c1d74f3 100644 --- a/mmdet3d/datasets/transforms/loading.py +++ b/mmdet3d/datasets/transforms/loading.py @@ -424,10 +424,7 @@ def __init__( self.load_dim = load_dim self.use_dim = use_dim self.file_client_args = file_client_args.copy() - if self.file_client_args is not None: - self.file_client = mmengine.FileClient(**self.file_client_args) - else: - self.file_client = None + self.file_client = None def _load_points(self, pts_filename: str) -> np.ndarray: """Private function to load point clouds data. 
diff --git a/tools/model_converters/pointrcnn_convert.py b/tools/model_converters/pointrcnn_convert.py deleted file mode 100644 index 7edda4c4f5..0000000000 --- a/tools/model_converters/pointrcnn_convert.py +++ /dev/null @@ -1,25 +0,0 @@ -import torch -import mmengine - -mm_path = '/home/PJLAB/shenkun/openmmlab-refactor/mmdetection3d/checkpoints/point_rcnn_2x8_kitti-3d-3classes_20211208_151344.pth' -pc_path = '/home/PJLAB/shenkun/workspace/OpenPCDet/checkpoint/pointrcnn_7870.pth' - -def main(): - new_dict = dict() - ori = torch.load(mm_path) - mm_dict = torch.load(mm_path)['state_dict'] - pc_dict = torch.load(pc_path)['model_state'] - pc_dict.pop('global_step') - for i in range(len(mm_dict.keys())): - mm_name = list(mm_dict.keys())[i] - if 'backbone' in mm_name and 'conv' in mm_name and 'bias' in mm_name: - continue - else: - new_dict[mm_name] = mm_dict[mm_name] - for i in range(len(new_dict.keys())): - new_dict[list(new_dict.keys())[i]] = pc_dict[list(pc_dict.keys())[i]] - ori['state_dict'] = new_dict - torch.save(ori,'new_pointrcnn.pth') - -if __name__ == '__main__': - main() From bbee4073776a973face5a1881ace28ff5b18cece Mon Sep 17 00:00:00 2001 From: VVsssssk Date: Wed, 28 Sep 2022 13:02:13 +0800 Subject: [PATCH 13/16] fix comments --- mmdet3d/models/dense_heads/point_rpn_head.py | 54 ++++++++++--------- mmdet3d/models/detectors/point_rcnn.py | 12 ++--- .../bbox_heads/point_rcnn_bbox_head.py | 32 +++++------ .../mask_heads/pointwise_semantic_head.py | 4 +- .../roi_heads/mask_heads/primitive_head.py | 43 +++++++-------- .../models/roi_heads/point_rcnn_roi_head.py | 10 ++-- .../single_roiaware_extractor.py | 8 ++- .../single_roipoint_extractor.py | 6 ++- 8 files changed, 90 insertions(+), 79 deletions(-) diff --git a/mmdet3d/models/dense_heads/point_rpn_head.py b/mmdet3d/models/dense_heads/point_rpn_head.py index ac7067645e..9c2203aae9 100644 --- a/mmdet3d/models/dense_heads/point_rpn_head.py +++ b/mmdet3d/models/dense_heads/point_rpn_head.py @@ -1,5 +1,5 @@ # 
Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import torch from mmengine.model import BaseModule @@ -40,15 +40,15 @@ class PointRPNHead(BaseModule): """ def __init__(self, - num_classes: dict, + num_classes: int, train_cfg: dict, test_cfg: dict, - pred_layer_cfg: dict = None, - enlarge_width: dict = 0.1, - cls_loss: dict = None, - bbox_loss: dict = None, - bbox_coder: dict = None, - init_cfg: dict = None) -> None: + pred_layer_cfg: Optional[dict] = None, + enlarge_width: float = 0.1, + cls_loss: Optional[dict] = None, + bbox_loss: Optional[dict] = None, + bbox_coder: Optional[dict] = None, + init_cfg: Optional[dict] = None) -> None: super().__init__(init_cfg=init_cfg) self.num_classes = num_classes self.train_cfg = train_cfg @@ -131,22 +131,24 @@ def forward(self, feat_dict: dict) -> Tuple[List[Tensor]]: batch_size, -1, self._get_reg_out_channels()) return point_box_preds, point_cls_preds - def loss_by_feat(self, - bbox_preds: List[Tensor], - cls_preds: List[Tensor], - points: List[Tensor], - batch_gt_instances_3d: InstanceList, - batch_input_metas: List[dict] = None, - batch_gt_instances_ignore: InstanceList = None) -> Dict: + def loss_by_feat( + self, + bbox_preds: List[Tensor], + cls_preds: List[Tensor], + points: List[Tensor], + batch_gt_instances_3d: InstanceList, + batch_input_metas: Optional[List[dict]] = None, + batch_gt_instances_ignore: Optional[InstanceList] = None) -> Dict: """Compute loss. Args: - bbox_preds (dict): Predictions from forward of PointRCNN RPN_Head. - cls_preds (dict): Classification from forward of PointRCNN - RPN_Head. + bbox_preds (list[torch.Tensor]): Predictions from forward of + PointRCNN RPN_Head. + cls_preds (list[torch.Tensor]): Classification from forward of + PointRCNN RPN_Head. points (list[torch.Tensor]): Input points. batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of - gt_instances. 
It usually includes ``bboxes_3d`` and + gt_instances_3d. It usually includes ``bboxes_3d`` and ``labels_3d`` attributes. batch_input_metas (list[dict]): Contain pcd and img's meta info. batch_gt_instances_ignore (list[:obj:`InstanceData`], optional): @@ -184,9 +186,9 @@ def get_targets(self, points: List[Tensor], """Generate targets of PointRCNN RPN head. Args: - points (list[torch.Tensor]): Points of each batch. + points (list[torch.Tensor]): Points in one batch. batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of - gt_instances. It usually includes ``bboxes_3d`` and + gt_instances_3d. It usually includes ``bboxes_3d`` and ``labels_3d`` attributes. Returns: @@ -264,13 +266,15 @@ def get_targets_single(self, points: Tensor, def predict_by_feat(self, points: Tensor, bbox_preds: List[Tensor], cls_preds: List[Tensor], batch_input_metas: List[dict], - cfg: Dict) -> InstanceList: + cfg: Optional[dict]) -> InstanceList: """Generate bboxes from RPN head predictions. Args: points (torch.Tensor): Input points. - bbox_preds (list): Regression predictions from PointRCNN head. - cls_preds (list): Class scores predictions from PointRCNN head. + bbox_preds (list[tensor]): Regression predictions from PointRCNN + head. + cls_preds (list[tensor]): Class scores predictions from PointRCNN + head. batch_input_metas (list[dict]): Batch inputs meta info. cfg (ConfigDict, optional): Test / postprocessing configuration. @@ -467,7 +471,7 @@ def predict(self, feats_dict: Dict, def loss_and_predict(self, feats_dict: Dict, batch_data_samples: SampleList, - proposal_cfg=None, + proposal_cfg: Optional[dict] = None, **kwargs) -> Tuple[dict, InstanceList]: """Perform forward propagation of the head, then calculate loss and predictions from the features and data samples. 
diff --git a/mmdet3d/models/detectors/point_rcnn.py b/mmdet3d/models/detectors/point_rcnn.py index 359ba0a067..acf6ab5031 100644 --- a/mmdet3d/models/detectors/point_rcnn.py +++ b/mmdet3d/models/detectors/point_rcnn.py @@ -26,12 +26,12 @@ class PointRCNN(TwoStage3DDetector): def __init__(self, backbone: dict, - neck: dict = None, - rpn_head: dict = None, - roi_head: dict = None, - train_cfg: dict = None, - test_cfg: dict = None, - init_cfg: dict = None, + neck: Optional[dict] = None, + rpn_head: Optional[dict] = None, + roi_head: Optional[dict] = None, + train_cfg: Optional[dict] = None, + test_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None, data_preprocessor: Optional[dict] = None) -> Optional: super(PointRCNN, self).__init__( backbone=backbone, diff --git a/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py b/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py index 43a288dd1c..1e7c53859d 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py +++ b/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, List, Tuple +from typing import Dict, List, Optional, Tuple import numpy as np import torch @@ -29,17 +29,17 @@ class PointRCNNBboxHead(BaseModule): mlp_channels (list[int]): the number of mlp channels pred_layer_cfg (dict, optional): Config of classfication and regression prediction layers. Defaults to None. - num_points (tuple, optional): The number of points which each SA + num_points (tuple): The number of points which each SA module samples. Defaults to (128, 32, -1). - radius (tuple, optional): Sampling radius of each SA module. + radius (tuple): Sampling radius of each SA module. Defaults to (0.2, 0.4, 100). - num_samples (tuple, optional): The number of samples for ball query + num_samples (tuple): The number of samples for ball query in each SA module. Defaults to (64, 64, 64). 
- sa_channels (tuple, optional): Out channels of each mlp in SA module. + sa_channels (tuple): Out channels of each mlp in SA module. Defaults to ((128, 128, 128), (128, 128, 256), (256, 256, 512)). - bbox_coder (dict, optional): Config dict of box coders. + bbox_coder (dict): Config dict of box coders. Defaults to dict(type='DeltaXYZWLHRBBoxCoder'). - sa_cfg (dict, optional): Config of set abstraction module, which may + sa_cfg (dict): Config of set abstraction module, which may contain the following keys and values: - pool_mod (str): Pool method ('max' or 'avg') for SA modules. @@ -48,20 +48,20 @@ class PointRCNNBboxHead(BaseModule): each SA module. Defaults to dict(type='PointSAModule', pool_mod='max', use_xyz=True). - conv_cfg (dict, optional): Config dict of convolutional layers. + conv_cfg (dict): Config dict of convolutional layers. Defaults to dict(type='Conv1d'). - norm_cfg (dict, optional): Config dict of normalization layers. + norm_cfg (dict): Config dict of normalization layers. Defaults to dict(type='BN1d'). - act_cfg (dict, optional): Config dict of activation layers. + act_cfg (dict): Config dict of activation layers. Defaults to dict(type='ReLU'). - bias (str, optional): Type of bias. Defaults to 'auto'. - loss_bbox (dict, optional): Config of regression loss function. + bias (str): Type of bias. Defaults to 'auto'. + loss_bbox (dict): Config of regression loss function. Defaults to dict(type='SmoothL1Loss', beta=1.0 / 9.0, reduction='sum', loss_weight=1.0). - loss_cls (dict, optional): Config of classification loss function. + loss_cls (dict): Config of classification loss function. Defaults to dict(type='CrossEntropyLoss', use_sigmoid=True, reduction='sum', loss_weight=1.0). - with_corner_loss (bool, optional): Whether using corner loss. + with_corner_loss (bool): Whether using corner loss. Defaults to True. init_cfg (dict, optional): Config of initialization. Defaults to None. 
""" @@ -70,7 +70,7 @@ def __init__(self, num_classes: dict, in_channels: dict, mlp_channels: dict, - pred_layer_cfg: dict = None, + pred_layer_cfg: Optional[dict] = None, num_points: dict = (128, 32, -1), radius: dict = (0.2, 0.4, 100), num_samples: dict = (64, 64, 64), @@ -94,7 +94,7 @@ def __init__(self, reduction='sum', loss_weight=1.0), with_corner_loss: bool = True, - init_cfg: dict = None) -> None: + init_cfg: Optional[dict] = None) -> None: super(PointRCNNBboxHead, self).__init__(init_cfg=init_cfg) self.num_classes = num_classes self.num_sa = len(sa_channels) diff --git a/mmdet3d/models/roi_heads/mask_heads/pointwise_semantic_head.py b/mmdet3d/models/roi_heads/mask_heads/pointwise_semantic_head.py index 70175a879b..9f055dd066 100644 --- a/mmdet3d/models/roi_heads/mask_heads/pointwise_semantic_head.py +++ b/mmdet3d/models/roi_heads/mask_heads/pointwise_semantic_head.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import Dict, Tuple +from typing import Dict, Optional, Tuple import torch from mmengine.model import BaseModule @@ -34,7 +34,7 @@ def __init__( num_classes: int = 3, extra_width: float = 0.2, seg_score_thr: float = 0.3, - init_cfg: dict = None, + init_cfg: Optional[dict] = None, loss_seg: dict = dict( type='FocalLoss', use_sigmoid=True, diff --git a/mmdet3d/models/roi_heads/mask_heads/primitive_head.py b/mmdet3d/models/roi_heads/mask_heads/primitive_head.py index 317c38dbe7..22f3892227 100644 --- a/mmdet3d/models/roi_heads/mask_heads/primitive_head.py +++ b/mmdet3d/models/roi_heads/mask_heads/primitive_head.py @@ -27,39 +27,42 @@ class PrimitiveHead(BaseModule): available mode ['z', 'xy', 'line']. bbox_coder (:obj:`BaseBBoxCoder`): Bbox coder for encoding and decoding boxes. - train_cfg (dict): Config for training. - test_cfg (dict): Config for testing. - vote_module_cfg (dict): Config of VoteModule for point-wise votes. - vote_aggregation_cfg (dict): Config of vote aggregation layer. 
+ train_cfg (dict, optional): Config for training. + test_cfg (dict, optional): Config for testing. + vote_module_cfg (dict, optional): Config of VoteModule for point-wise + votes. + vote_aggregation_cfg (dict, optional): Config of vote aggregation + layer. feat_channels (tuple[int]): Convolution channels of prediction layer. upper_thresh (float): Threshold for line matching. surface_thresh (float): Threshold for surface matching. - conv_cfg (dict): Config of convolution in prediction layer. - norm_cfg (dict): Config of BN in prediction layer. - objectness_loss (dict): Config of objectness loss. - center_loss (dict): Config of center loss. - semantic_loss (dict): Config of point-wise semantic segmentation loss. + conv_cfg (dict, optional): Config of convolution in prediction layer. + norm_cfg (dict, optional): Config of BN in prediction layer. + objectness_loss (dict, optional): Config of objectness loss. + center_loss (dict, optional): Config of center loss. + semantic_loss (dict, optional): Config of point-wise semantic + segmentation loss. 
""" def __init__(self, num_dims: int, num_classes: int, primitive_mode: str, - train_cfg: dict = None, - test_cfg: dict = None, - vote_module_cfg: dict = None, - vote_aggregation_cfg: dict = None, + train_cfg: Optional[dict] = None, + test_cfg: Optional[dict] = None, + vote_module_cfg: Optional[dict] = None, + vote_aggregation_cfg: Optional[dict] = None, feat_channels: tuple = (128, 128), upper_thresh: float = 100.0, surface_thresh: float = 0.5, conv_cfg: dict = dict(type='Conv1d'), norm_cfg: dict = dict(type='BN1d'), - objectness_loss: dict = None, - center_loss: dict = None, - semantic_reg_loss: dict = None, - semantic_cls_loss: dict = None, - init_cfg: dict = None): + objectness_loss: Optional[dict] = None, + center_loss: Optional[dict] = None, + semantic_reg_loss: Optional[dict] = None, + semantic_cls_loss: Optional[dict] = None, + init_cfg: Optional[dict] = None): super(PrimitiveHead, self).__init__(init_cfg=init_cfg) # bounding boxes centers, face centers and edge centers assert primitive_mode in ['z', 'xy', 'line'] @@ -256,10 +259,8 @@ def loss_by_feat( attributes. batch_pts_semantic_mask (list[tensor]): Semantic mask of points cloud. Defaults to None. - batch_pts_semantic_mask (list[tensor]): Instance mask + batch_pts_instance_mask (list[tensor]): Instance mask of points cloud. Defaults to None. - batch_input_metas (list[dict]): Contain pcd and img's meta info. - ret_target (bool): Return targets or not. Defaults to False. Returns: dict: Losses of Primitive Head. diff --git a/mmdet3d/models/roi_heads/point_rcnn_roi_head.py b/mmdet3d/models/roi_heads/point_rcnn_roi_head.py index e7942fc426..4ebdcfa93f 100644 --- a/mmdet3d/models/roi_heads/point_rcnn_roi_head.py +++ b/mmdet3d/models/roi_heads/point_rcnn_roi_head.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from typing import Dict +from typing import Dict, Optional import torch from torch import Tensor @@ -21,7 +21,7 @@ class PointRCNNRoIHead(Base3DRoIHead): bbox_roi_extractor (dict): Config of RoI extractor. train_cfg (dict): Train configs. test_cfg (dict): Test configs. - depth_normalizer (float, optional): Normalize depth feature. + depth_normalizer (float): Normalize depth feature. Defaults to 70.0. init_cfg (dict, optional): Config of initialization. Defaults to None. """ @@ -32,7 +32,7 @@ def __init__(self, train_cfg: dict, test_cfg: dict, depth_normalizer: dict = 70.0, - init_cfg: dict = None) -> None: + init_cfg: Optional[dict] = None) -> None: super(PointRCNNRoIHead, self).__init__( bbox_head=bbox_head, bbox_roi_extractor=bbox_roi_extractor, @@ -182,7 +182,7 @@ def _bbox_forward_train(self, features: Tensor, points: Tensor, Args: features (torch.Tensor): Backbone features with depth and \ semantic features. - points (torch.Tensor): Pointcloud. + points (torch.Tensor): Point cloud. sampling_results (:obj:`SamplingResult`): Sampled results used for training. @@ -210,7 +210,7 @@ def _bbox_forward(self, features: Tensor, points: Tensor, batch_size: int, Args: features (torch.Tensor): Backbone features with depth and semantic features. - points (torch.Tensor): Pointcloud. + points (torch.Tensor): Point cloud. batch_size (int): Batch size. rois (torch.Tensor): RoI boxes. diff --git a/mmdet3d/models/roi_heads/roi_extractors/single_roiaware_extractor.py b/mmdet3d/models/roi_heads/roi_extractors/single_roiaware_extractor.py index a9ad793907..00756cc3db 100644 --- a/mmdet3d/models/roi_heads/roi_extractors/single_roiaware_extractor.py +++ b/mmdet3d/models/roi_heads/roi_extractors/single_roiaware_extractor.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + import torch import torch.nn as nn from mmcv import ops @@ -15,10 +17,12 @@ class Single3DRoIAwareExtractor(BaseModule): Extract Point-wise roi features. 
Args: - roi_layer (dict): The config of roi layer. + roi_layer (dict, optional): The config of roi layer. """ - def __init__(self, roi_layer: dict = None, init_cfg: dict = None) -> None: + def __init__(self, + roi_layer: Optional[dict] = None, + init_cfg: Optional[dict] = None) -> None: super(Single3DRoIAwareExtractor, self).__init__(init_cfg=init_cfg) self.roi_layer = self.build_roi_layers(roi_layer) diff --git a/mmdet3d/models/roi_heads/roi_extractors/single_roipoint_extractor.py b/mmdet3d/models/roi_heads/roi_extractors/single_roipoint_extractor.py index 043ee1be95..2697d25e5e 100644 --- a/mmdet3d/models/roi_heads/roi_extractors/single_roipoint_extractor.py +++ b/mmdet3d/models/roi_heads/roi_extractors/single_roipoint_extractor.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +from typing import Optional + import torch import torch.nn as nn from mmcv import ops @@ -15,10 +17,10 @@ class Single3DRoIPointExtractor(nn.Module): Extract Point-wise roi features. Args: - roi_layer (dict): The config of roi layer. + roi_layer (dict, optional): The config of roi layer. 
""" - def __init__(self, roi_layer: dict = None) -> None: + def __init__(self, roi_layer: Optional[dict] = None) -> None: super(Single3DRoIPointExtractor, self).__init__() self.roi_layer = self.build_roi_layers(roi_layer) From 55c8713bb92f3291d1dbd3347159192c500629ef Mon Sep 17 00:00:00 2001 From: VVsssssk Date: Wed, 28 Sep 2022 13:05:05 +0800 Subject: [PATCH 14/16] fix --- mmdet3d/models/dense_heads/point_rpn_head.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/mmdet3d/models/dense_heads/point_rpn_head.py b/mmdet3d/models/dense_heads/point_rpn_head.py index 9c2203aae9..e035461c81 100644 --- a/mmdet3d/models/dense_heads/point_rpn_head.py +++ b/mmdet3d/models/dense_heads/point_rpn_head.py @@ -200,13 +200,6 @@ def get_targets(self, points: List[Tensor], gt_bboxes_3d = [ instances.bboxes_3d for instances in batch_gt_instances_3d ] - # find empty example - for index in range(len(gt_labels_3d)): - if len(gt_labels_3d[index]) == 0: - fake_box = gt_bboxes_3d[index].tensor.new_zeros( - 1, gt_bboxes_3d[index].tensor.shape[-1]) - gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box) - gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1) (bbox_targets, mask_targets, positive_mask, negative_mask, point_targets) = multi_apply(self.get_targets_single, points, From 1b5ecce638270a79807b631285780c04d4268982 Mon Sep 17 00:00:00 2001 From: VVsssssk Date: Wed, 28 Sep 2022 13:06:32 +0800 Subject: [PATCH 15/16] fix --- mmdet3d/models/backbones/pointnet2_sa_msg.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmdet3d/models/backbones/pointnet2_sa_msg.py b/mmdet3d/models/backbones/pointnet2_sa_msg.py index 5e88ee01a2..18bfae7695 100644 --- a/mmdet3d/models/backbones/pointnet2_sa_msg.py +++ b/mmdet3d/models/backbones/pointnet2_sa_msg.py @@ -105,7 +105,7 @@ def __init__(self, dilated_group=dilated_group[sa_index], norm_cfg=norm_cfg, cfg=sa_cfg, - bias=False)) + bias=True)) skip_channel_list.append(sa_out_channel) cur_aggregation_channel = 
aggregation_channels[sa_index] From 628d8d03d8fd779a8839b0b5c40c9c6d1f50252b Mon Sep 17 00:00:00 2001 From: Tai-Wang Date: Fri, 30 Sep 2022 16:40:14 +0800 Subject: [PATCH 16/16] Minor fix --- mmdet3d/models/dense_heads/point_rpn_head.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmdet3d/models/dense_heads/point_rpn_head.py b/mmdet3d/models/dense_heads/point_rpn_head.py index e035461c81..c65c7fff5a 100644 --- a/mmdet3d/models/dense_heads/point_rpn_head.py +++ b/mmdet3d/models/dense_heads/point_rpn_head.py @@ -143,7 +143,7 @@ def loss_by_feat( Args: bbox_preds (list[torch.Tensor]): Predictions from forward of - PointRCNN RPN_Head. + PointRCNN RPN_Head. cls_preds (list[torch.Tensor]): Classification from forward of PointRCNN RPN_Head. points (list[torch.Tensor]): Input points.