From 46f96545e2a7de2ab1514ef6a00d5cb3498eb9a8 Mon Sep 17 00:00:00 2001 From: Wenhao Wu Date: Wed, 1 Dec 2021 21:32:09 +0800 Subject: [PATCH 1/7] rebase & resubmit --- configs/_base_/models/point_rcnn.py | 130 ++++ configs/pointrcnn/README.md | 25 + .../pointrcnn_2x8_kitti-3d-3classes.py | 94 +++ mmdet3d/datasets/pipelines/transforms_3d.py | 10 +- mmdet3d/models/dense_heads/__init__.py | 3 +- mmdet3d/models/dense_heads/point_rpn_head.py | 332 ++++++++++ mmdet3d/models/dense_heads/ssd_3d_head.py | 2 +- mmdet3d/models/detectors/__init__.py | 3 +- mmdet3d/models/detectors/pointrcnn.py | 144 +++++ mmdet3d/models/necks/pointnet2_fp_neck.py | 2 +- mmdet3d/models/roi_heads/__init__.py | 3 +- .../models/roi_heads/bbox_heads/__init__.py | 3 +- .../roi_heads/bbox_heads/parta2_bbox_head.py | 27 +- .../bbox_heads/point_rcnn_bbox_head.py | 580 ++++++++++++++++++ .../models/roi_heads/point_rcnn_roi_head.py | 286 +++++++++ mmdet3d/ops/__init__.py | 3 +- tests/test_models/test_detectors.py | 29 + tests/test_models/test_heads/test_heads.py | 117 +++- tests/test_runtime/test_config.py | 27 + 19 files changed, 1791 insertions(+), 29 deletions(-) create mode 100644 configs/_base_/models/point_rcnn.py create mode 100644 configs/pointrcnn/README.md create mode 100644 configs/pointrcnn/pointrcnn_2x8_kitti-3d-3classes.py create mode 100644 mmdet3d/models/dense_heads/point_rpn_head.py create mode 100644 mmdet3d/models/detectors/pointrcnn.py create mode 100644 mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py create mode 100644 mmdet3d/models/roi_heads/point_rcnn_roi_head.py diff --git a/configs/_base_/models/point_rcnn.py b/configs/_base_/models/point_rcnn.py new file mode 100644 index 0000000000..e32c6f5084 --- /dev/null +++ b/configs/_base_/models/point_rcnn.py @@ -0,0 +1,130 @@ +model = dict( + type='PointRCNN', + backbone=dict( + type='PointNet2SAMSG', + in_channels=4, + num_points=(4096, 1024, 256, 64), + radii=((0.1, 0.5), (0.5, 1.0), (1.0, 2.0), (2.0, 4.0)), + num_samples=((16, 32), (16, 32), (16, 32), (16, 32)), + sa_channels=(((16, 16, 32), (32, 32, 64)), ((64, 64, 128), (64, 96, + 128)), + ((128, 196, 256), (128, 196, 256)), ((256, 256, 512), + (256, 384, 512))), + fps_mods=(('D-FPS'), ('D-FPS'), ('D-FPS'), ('D-FPS')), + fps_sample_range_lists=((-1), (-1), (-1), (-1)), + aggregation_channels=(None, None, None, None), + dilated_group=(False, False, False, False), + out_indices=(0, 1, 2, 3), + norm_cfg=dict(type='BN2d', eps=1e-3, momentum=0.1), + sa_cfg=dict( + type='PointSAModuleMSG', + pool_mod='max', + use_xyz=True, + normalize_xyz=False)), + neck=dict( + type='PointNetFPNeck', + fp_channels=((1536, 512, 512), (768, 512, 512), (608, 256, 256), + (257, 128, 128))), + rpn_head=dict( + type='PointRPNHead', + num_classes=3, + enlarge_width=0.1, + pred_layer_cfg=dict( + in_channels=128, + cls_linear_channels=(256, 256), + reg_linear_channels=(256, 256)), + cls_loss=dict( + type='FocalLoss', + use_sigmoid=True, + reduction='sum', + gamma=2.0, + alpha=0.25, + loss_weight=1.0), + bbox_loss=dict( + type='SmoothL1Loss', + beta=1.0 / 9.0, + reduction='sum', + loss_weight=1.0), + bbox_coder=dict( + type='PointXYZWHLRBBoxCoder', + code_size=8, + # code_size: (center residual (3), size regression (3), + # torch.cos(yaw) (1), torch.sin(yaw) (1) + use_mean_size=True, + mean_size=[[3.9, 1.6, 1.56], [0.8, 0.6, 1.73], [1.76, 0.6, + 1.73]])), + roi_head=dict( + type='PointRCNNRoIHead', + point_roi_extractor=dict( + type='Single3DRoIPointExtractor', + roi_layer=dict(type='RoIPointPool3d', num_sampled_points=512)), 
+    bbox_head=dict(
+        type='PointRCNNBboxHead',
+        num_classes=1,
+        pred_layer_cfg=dict(
+            in_channels=512,
+            cls_conv_channels=(256, 256),
+            reg_conv_channels=(256, 256),
+            bias=True),
+        in_channels=5,
+        # 5 = 3 (xyz) + scores + depth
+        mlp_channels=[128, 128],
+        num_points=(128, 32, -1),
+        radius=(0.2, 0.4, 100),
+        num_samples=(16, 16, 16),
+        sa_channels=((128, 128, 128), (128, 128, 256), (256, 256, 512)),
+        with_corner_loss=True),
+    depth_normalizer=70.0),
+    # model training and testing settings
+    train_cfg=dict(
+        pos_distance_thr=10.0,
+        rpn=dict(
+            nms_cfg=dict(
+                use_rotate_nms=True, iou_thr=0.8, nms_pre=9000, nms_post=512),
+            score_thr=None),
+        rcnn=dict(
+            assigner=[
+                dict(  # for Car
+                    type='MaxIoUAssigner',
+                    iou_calculator=dict(
+                        type='BboxOverlaps3D', coordinate='lidar'),
+                    pos_iou_thr=0.55,
+                    neg_iou_thr=0.55,
+                    min_pos_iou=0.55,
+                    ignore_iof_thr=-1),
+                dict(  # for Pedestrian
+                    type='MaxIoUAssigner',
+                    iou_calculator=dict(
+                        type='BboxOverlaps3D', coordinate='lidar'),
+                    pos_iou_thr=0.55,
+                    neg_iou_thr=0.55,
+                    min_pos_iou=0.55,
+                    ignore_iof_thr=-1),
+                dict(  # for Cyclist
+                    type='MaxIoUAssigner',
+                    iou_calculator=dict(
+                        type='BboxOverlaps3D', coordinate='lidar'),
+                    pos_iou_thr=0.55,
+                    neg_iou_thr=0.55,
+                    min_pos_iou=0.55,
+                    ignore_iof_thr=-1)
+            ],
+            sampler=dict(
+                type='IoUNegPiecewiseSampler',
+                num=128,
+                pos_fraction=0.5,
+                neg_piece_fractions=[0.8, 0.2],
+                neg_iou_piece_thrs=[0.55, 0.1],
+                neg_pos_ub=-1,
+                add_gt_as_proposals=False,
+                return_iou=True),
+            cls_pos_thr=0.7,
+            cls_neg_thr=0.25)),
+    test_cfg=dict(
+        rpn=dict(
+            nms_cfg=dict(
+                use_rotate_nms=True, iou_thr=0.85, nms_pre=9000, nms_post=512),
+            score_thr=None),
+        rcnn=dict(use_rotate_nms=True, nms_thr=0.1, score_thr=0.1)))
+
+find_unused_parameters = True
diff --git a/configs/pointrcnn/README.md b/configs/pointrcnn/README.md
new file mode 100644
index 0000000000..d643d97c6f
--- /dev/null
+++ b/configs/pointrcnn/README.md
@@ -0,0 +1,25 @@
+# PointRCNN: 3D Object Proposal Generation and Detection from Point Cloud
+
+## Introduction
+
+
+We implement PointRCNN and provide its results with checkpoints on the KITTI dataset.
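+
+The snippet below is an illustrative inference sketch using the high-level
+APIs shipped with MMDetection3D; the checkpoint path is a placeholder for a
+locally trained or downloaded model:
+
+```python
+from mmdet3d.apis import inference_detector, init_model
+
+config_file = 'configs/pointrcnn/pointrcnn_2x8_kitti-3d-3classes.py'
+checkpoint_file = 'checkpoints/pointrcnn_kitti.pth'  # placeholder path
+
+# build the model from the config file and the (hypothetical) checkpoint
+model = init_model(config_file, checkpoint_file, device='cuda:0')
+
+# run inference on a single KITTI point cloud file
+result, data = inference_detector(model, 'demo/data/kitti/kitti_000008.bin')
+```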
+
+```
+@InProceedings{Shi_2019_CVPR,
+    author = {Shi, Shaoshuai and Wang, Xiaogang and Li, Hongsheng},
+    title = {PointRCNN: 3D Object Proposal Generation and Detection From Point Cloud},
+    booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
+    month = {June},
+    year = {2019}
+}
+```
+
+## Results
+
+### KITTI
+
+|                      Backbone                      |  Class  |  Lr schd   | Mem (GB) | Inf time (fps) |  mAP  | Download |
+| :------------------------------------------------: | :-----: | :--------: | :------: | :------------: | :---: | :------: |
+| [PointNet++](./pointrcnn_2x8_kitti-3d-3classes.py) | 3 Class | cyclic 80e |   7.1    |                | 70.39 |          |
diff --git a/configs/pointrcnn/pointrcnn_2x8_kitti-3d-3classes.py b/configs/pointrcnn/pointrcnn_2x8_kitti-3d-3classes.py
new file mode 100644
index 0000000000..1344aca5c5
--- /dev/null
+++ b/configs/pointrcnn/pointrcnn_2x8_kitti-3d-3classes.py
@@ -0,0 +1,94 @@
+_base_ = [
+    '../_base_/datasets/kitti-3d-car.py', '../_base_/models/point_rcnn.py',
+    '../_base_/default_runtime.py', '../_base_/schedules/cyclic_40e.py'
+]
+
+# dataset settings
+dataset_type = 'KittiDataset'
+data_root = 'data/kitti/'
+class_names = ['Car', 'Pedestrian', 'Cyclist']
+point_cloud_range = [0, -40, -3, 70.4, 40, 1]
+input_modality = dict(use_lidar=True, use_camera=False)
+
+db_sampler = dict(
+    data_root=data_root,
+    info_path=data_root + 'kitti_dbinfos_train.pkl',
+    rate=1.0,
+    prepare=dict(
+        filter_by_difficulty=[-1],
+        filter_by_min_points=dict(Car=5, Pedestrian=5, Cyclist=5)),
+    sample_groups=dict(Car=20, Pedestrian=15, Cyclist=15),
+    classes=class_names)
+
+train_pipeline = [
+    dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
+    dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True),
+    dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
+    dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range),
+    dict(type='ObjectSample', db_sampler=db_sampler),
+    dict(type='RandomFlip3D', flip_ratio_bev_horizontal=0.5),
+    dict(
+        type='ObjectNoise',
+        num_try=100,
+        translation_std=[1.0, 1.0, 0.5],
+        global_rot_range=[0.0, 0.0],
+        rot_range=[-0.78539816, 0.78539816]),
+    dict(
+        type='GlobalRotScaleTrans',
+        rot_range=[-0.78539816, 0.78539816],
+        scale_ratio_range=[0.95, 1.05]),
+    dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
+    dict(type='PointSample', num_points=16384, sample_range=40.0),
+    dict(type='PointShuffle'),
+    dict(type='DefaultFormatBundle3D', class_names=class_names),
+    dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
+]
+test_pipeline = [
+    dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
+    dict(
+        type='MultiScaleFlipAug3D',
+        img_scale=(1333, 800),
+        pts_scale_ratio=1,
+        flip=False,
+        transforms=[
+            dict(
+                type='GlobalRotScaleTrans',
+                rot_range=[0, 0],
+                scale_ratio_range=[1., 1.],
+                translation_std=[0, 0, 0]),
+            dict(type='RandomFlip3D'),
+            dict(
+                type='PointsRangeFilter', point_cloud_range=point_cloud_range),
+            dict(type='PointSample', num_points=16384, sample_range=40.0),
+            dict(
+                type='DefaultFormatBundle3D',
+                class_names=class_names,
+                with_label=False),
+            dict(type='Collect3D', keys=['points'])
+        ])
+]
+
+data = dict(
+    samples_per_gpu=2,
+    workers_per_gpu=2,
+    train=dict(
+        type='RepeatDataset',
+        times=2,
+        dataset=dict(pipeline=train_pipeline, classes=class_names)),
+    val=dict(pipeline=test_pipeline, classes=class_names),
+    test=dict(pipeline=test_pipeline, classes=class_names))
+
+# optimizer
+lr = 0.001  # max learning rate
+optimizer = dict(lr=lr, betas=(0.95, 0.85))
+# runtime settings
+runner = dict(type='EpochBasedRunner', max_epochs=80)
+evaluation = dict(interval=2)
+# yapf:disable
+log_config = dict(
+    interval=30,
+    hooks=[
+        dict(type='TextLoggerHook'),
+        dict(type='TensorboardLoggerHook')
+    ])
+# yapf:enable
diff --git a/mmdet3d/datasets/pipelines/transforms_3d.py b/mmdet3d/datasets/pipelines/transforms_3d.py
index 0cb95797f1..7045a8d039 100644
--- a/mmdet3d/datasets/pipelines/transforms_3d.py
+++ b/mmdet3d/datasets/pipelines/transforms_3d.py
@@ -892,8 +892,8 @@ def _points_random_sampling(self,
         if sample_range is not None and not replace:
             # Only sampling the near points when len(points) >= num_samples
             depth = np.linalg.norm(points.tensor, axis=1)
-            far_inds = np.where(depth > sample_range)[0]
-            near_inds = np.where(depth <= sample_range)[0]
+            far_inds = np.where(depth >= sample_range)[0]
+            near_inds = np.where(depth < sample_range)[0]
             # in case there are too many far points
             if len(far_inds) > num_samples:
                 far_inds = np.random.choice(
@@ -920,12 +920,6 @@ def __call__(self, results):
                 and 'pts_semantic_mask' keys are updated in the result dict.
         """
         points = results['points']
-        # Points in Camera coord can provide the depth information.
-        # TODO: Need to support distance-based sampling for other coord system.
-        if self.sample_range is not None:
-            from mmdet3d.core.points import CameraPoints
-            assert isinstance(points, CameraPoints), 'Sampling based on' \
-                'distance is only applicable for CAMERA coord'
         points, choices = self._points_random_sampling(
             points,
             self.num_points,
diff --git a/mmdet3d/models/dense_heads/__init__.py b/mmdet3d/models/dense_heads/__init__.py
index 0a9e50d637..d182b8f03d 100644
--- a/mmdet3d/models/dense_heads/__init__.py
+++ b/mmdet3d/models/dense_heads/__init__.py
@@ -9,6 +9,7 @@
 from .groupfree3d_head import GroupFree3DHead
 from .parta2_rpn_head import PartA2RPNHead
 from .pgd_head import PGDHead
+from .point_rpn_head import PointRPNHead
 from .shape_aware_head import ShapeAwareHead
 from .smoke_mono3d_head import SMOKEMono3DHead
 from .ssd_3d_head import SSD3DHead
@@ -18,5 +19,5 @@
     'Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'VoteHead',
     'SSD3DHead', 'BaseConvBboxHead', 'CenterHead', 'ShapeAwareHead',
     'BaseMono3DDenseHead', 'AnchorFreeMono3DHead', 'FCOSMono3DHead',
-    'GroupFree3DHead', 'SMOKEMono3DHead', 'PGDHead'
+    'GroupFree3DHead', 'PointRPNHead', 'SMOKEMono3DHead', 'PGDHead'
 ]
diff --git a/mmdet3d/models/dense_heads/point_rpn_head.py b/mmdet3d/models/dense_heads/point_rpn_head.py
new file mode 100644
index 0000000000..ac2548105b
--- /dev/null
+++ b/mmdet3d/models/dense_heads/point_rpn_head.py
@@ -0,0 +1,332 @@
+import torch
+from mmcv.runner import BaseModule, force_fp32
+from torch import nn as nn
+
+from mmdet3d.core.bbox.structures import (DepthInstance3DBoxes,
+                                          LiDARInstance3DBoxes)
+from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu
+from mmdet.core import build_bbox_coder, multi_apply
+from mmdet.models import HEADS, build_loss
+
+
+@HEADS.register_module()
+class PointRPNHead(BaseModule):
+    """RPN module for PointRCNN.
+
+    Args:
+        num_classes (int): Number of classes.
+        train_cfg (dict): Train configs.
+        test_cfg (dict): Test configs.
+        pred_layer_cfg (dict, optional): Config of classification and
+            regression prediction layers. Defaults to None.
+        enlarge_width (float, optional): Enlarge bbox for each side to ignore
+            close points. Defaults to 0.1.
+        cls_loss (dict, optional): Config of semantic classification loss.
+            Defaults to None.
+        bbox_loss (dict, optional): Config of localization loss.
+            Defaults to None.
+        bbox_coder (dict, optional): Config dict of box coders.
+            Defaults to None.
+        init_cfg (dict, optional): Config of initialization. Defaults to None.
+    """
+
+    def __init__(self,
+                 num_classes,
+                 train_cfg,
+                 test_cfg,
+                 pred_layer_cfg=None,
+                 enlarge_width=0.1,
+                 cls_loss=None,
+                 bbox_loss=None,
+                 bbox_coder=None,
+                 init_cfg=None):
+        super().__init__(init_cfg=init_cfg)
+        self.num_classes = num_classes
+        self.train_cfg = train_cfg
+        self.test_cfg = test_cfg
+        self.enlarge_width = enlarge_width
+
+        # build loss function
+        self.bbox_loss = build_loss(bbox_loss)
+        self.cls_loss = build_loss(cls_loss)
+
+        # build box coder
+        self.bbox_coder = build_bbox_coder(bbox_coder)
+
+        # build pred conv
+        self.cls_layers = self._make_fc_layers(
+            fc_cfg=pred_layer_cfg.cls_linear_channels,
+            input_channels=pred_layer_cfg.in_channels,
+            output_channels=self._get_cls_out_channels())
+
+        self.reg_layers = self._make_fc_layers(
+            fc_cfg=pred_layer_cfg.reg_linear_channels,
+            input_channels=pred_layer_cfg.in_channels,
+            output_channels=self._get_reg_out_channels())
+
+    def _make_fc_layers(self, fc_cfg, input_channels, output_channels):
+        """Make fully connected layers.
+
+        Args:
+            fc_cfg (tuple[int]): Channels of the hidden fully connected
+                layers.
+            input_channels (int): Input channels of the first layer.
+            output_channels (int): Output channels of the last layer.
+
+        Returns:
+            nn.Sequential: Fully connected layers.
+        """
+        fc_layers = []
+        c_in = input_channels
+        for k in range(len(fc_cfg)):
+            fc_layers.extend([
+                nn.Linear(c_in, fc_cfg[k], bias=False),
+                nn.BatchNorm1d(fc_cfg[k]),
+                nn.ReLU(),
+            ])
+            c_in = fc_cfg[k]
+        fc_layers.append(nn.Linear(c_in, output_channels, bias=True))
+        return nn.Sequential(*fc_layers)
+
+    def _get_cls_out_channels(self):
+        """Return the channel number of classification outputs."""
+        # Class numbers (k)
+        return self.num_classes
+
+    def _get_reg_out_channels(self):
+        """Return the channel number of regression outputs."""
+        # Bbox regression: center residual (3), size regression (3),
+        # torch.cos(yaw) (1) and torch.sin(yaw) (1)
+        return self.bbox_coder.code_size
+
+    def forward(self, feat_dict):
+        """Forward pass.
+
+        Args:
+            feat_dict (dict): Feature dict from backbone and neck, which
+                should contain the key `fp_features`.
+
+        Returns:
+            tuple[torch.Tensor]: Box predictions of shape
+                (B, N, code_size) and class predictions of shape
+                (B, N, num_classes).
+        """
+        point_features = feat_dict['fp_features']
+        point_features = point_features.permute(0, 2, 1).contiguous()
+        bs = point_features.shape[0]
+        x_cls = point_features.view(-1, point_features.shape[-1])
+        x_reg = point_features.view(-1, point_features.shape[-1])
+
+        point_cls_preds = self.cls_layers(x_cls).reshape(
+            bs, -1, self._get_cls_out_channels())
+        point_box_preds = self.reg_layers(x_reg).reshape(
+            bs, -1, self._get_reg_out_channels())
+        return (point_box_preds, point_cls_preds)
+
+    @force_fp32(apply_to=('bbox_preds', ))
+    def loss(self,
+             bbox_preds,
+             cls_preds,
+             points,
+             gt_bboxes_3d,
+             gt_labels_3d,
+             img_metas=None):
+        """Compute loss.
+
+        Args:
+            bbox_preds (torch.Tensor): Box predictions from the forward
+                of the PointRCNN RPN head.
+            cls_preds (torch.Tensor): Classification predictions from the
+                forward of the PointRCNN RPN head.
+            points (list[torch.Tensor]): Input points.
+            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
+                bboxes of each sample.
+            gt_labels_3d (list[torch.Tensor]): Labels of each sample.
+            img_metas (list[dict], optional): Contain pcd and img's meta
+                info. Defaults to None.
+
+        Returns:
+            dict: Losses of PointRCNN RPN module.
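+
+        Example:
+            A minimal, runnable sketch of how the semantic targets are
+            built below: negative points are mapped to the extra
+            background bin ``num_classes`` (a clone is used here for
+            clarity, while the code below mutates in place):
+
+                >>> import torch
+                >>> num_classes = 3
+                >>> mask_targets = torch.tensor([0, 2, 1])
+                >>> negative_mask = torch.tensor([False, True, False])
+                >>> semantic_targets = mask_targets.clone()
+                >>> semantic_targets[negative_mask] = num_classes
+                >>> semantic_targets  # background point gets label 3
+                tensor([0, 3, 1])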
+        """
+        targets = self.get_targets(points, gt_bboxes_3d, gt_labels_3d)
+        (bbox_targets, mask_targets, positive_mask, negative_mask,
+         box_loss_weights, point_targets) = targets
+
+        # bbox loss
+        bbox_loss = self.bbox_loss(bbox_preds, bbox_targets,
+                                   box_loss_weights.unsqueeze(-1))
+        # calculate semantic loss
+        semantic_points = cls_preds.reshape(-1, self.num_classes)
+        semantic_targets = mask_targets
+        semantic_targets[negative_mask] = self.num_classes
+        semantic_points_label = semantic_targets
+        # points that are neither positive nor negative (i.e. inside the
+        # enlarged ignore region) get zero loss weight
+        semantic_loss_weight = negative_mask.float() + positive_mask.float()
+        semantic_loss = self.cls_loss(semantic_points,
+                                      semantic_points_label.reshape(-1),
+                                      semantic_loss_weight.reshape(-1))
+        semantic_loss /= positive_mask.float().sum()
+        losses = dict(bbox_loss=bbox_loss, semantic_loss=semantic_loss)
+
+        return losses
+
+    def get_targets(self, points, gt_bboxes_3d, gt_labels_3d):
+        """Generate targets of PointRCNN RPN head.
+
+        Args:
+            points (list[torch.Tensor]): Points of each batch.
+            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
+                bboxes of each batch.
+            gt_labels_3d (list[torch.Tensor]): Labels of each batch.
+
+        Returns:
+            tuple[torch.Tensor]: Targets of PointRCNN RPN head.
+        """
+        # find empty example
+        for index in range(len(gt_labels_3d)):
+            if len(gt_labels_3d[index]) == 0:
+                fake_box = gt_bboxes_3d[index].tensor.new_zeros(
+                    1, gt_bboxes_3d[index].tensor.shape[-1])
+                gt_bboxes_3d[index] = gt_bboxes_3d[index].new_box(fake_box)
+                gt_labels_3d[index] = gt_labels_3d[index].new_zeros(1)
+
+        (bbox_targets, mask_targets, positive_mask, negative_mask,
+         point_targets) = multi_apply(self.get_targets_single, points,
+                                      gt_bboxes_3d, gt_labels_3d)
+
+        bbox_targets = torch.stack(bbox_targets)
+        mask_targets = torch.stack(mask_targets)
+        positive_mask = torch.stack(positive_mask)
+        negative_mask = torch.stack(negative_mask)
+        box_loss_weights = positive_mask / (positive_mask.sum() + 1e-6)
+
+        return (bbox_targets, mask_targets, positive_mask, negative_mask,
+                box_loss_weights, point_targets)
+
+    def get_targets_single(self, points, gt_bboxes_3d, gt_labels_3d):
+        """Generate targets of PointRCNN RPN head for a single sample.
+
+        Args:
+            points (torch.Tensor): Points of each batch.
+            gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): Ground truth
+                boxes of each batch.
+            gt_labels_3d (torch.Tensor): Labels of each batch.
+
+        Returns:
+            tuple[torch.Tensor]: Targets of PointRCNN RPN head.
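+
+        Example:
+            A minimal sketch of how the per-point foreground mask is
+            reduced from the point-in-box membership matrix used below
+            (shapes are illustrative):
+
+                >>> import torch
+                >>> # points_mask: (num_points, num_gts) box membership
+                >>> points_mask = torch.tensor([[1, 0], [0, 0], [0, 1]])
+                >>> positive_mask = points_mask.max(1)[0] > 0
+                >>> positive_mask
+                tensor([ True, False,  True])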
+        """
+        gt_bboxes_3d = gt_bboxes_3d.to(points.device)
+
+        valid_gt = gt_labels_3d != -1
+        gt_bboxes_3d = gt_bboxes_3d[valid_gt]
+        gt_labels_3d = gt_labels_3d[valid_gt]
+
+        # transform the bbox coordinate to the point cloud coordinate
+        gt_bboxes_3d_tensor = gt_bboxes_3d.tensor.clone()
+        gt_bboxes_3d_tensor[..., 2] += gt_bboxes_3d_tensor[..., 5] / 2
+
+        points_mask, assignment = self._assign_targets_by_points_inside(
+            gt_bboxes_3d, points)
+        gt_bboxes_3d_tensor = gt_bboxes_3d_tensor[assignment]
+        mask_targets = gt_labels_3d[assignment]
+
+        bbox_targets = self.bbox_coder.encode(gt_bboxes_3d_tensor,
+                                              points[..., 0:3], mask_targets)
+
+        positive_mask = (points_mask.max(1)[0] > 0)
+        # points inside the enlarged boxes but outside the original ones
+        # are treated as ignored by excluding them from the negatives
+        extend_gt_bboxes_3d = gt_bboxes_3d.enlarged_box(self.enlarge_width)
+        points_mask, _ = self._assign_targets_by_points_inside(
+            extend_gt_bboxes_3d, points)
+        negative_mask = (points_mask.max(1)[0] == 0)
+
+        point_targets = points[..., 0:3]
+        return (bbox_targets, mask_targets, positive_mask, negative_mask,
+                point_targets)
+
+    def get_bboxes(self,
+                   points,
+                   bbox_preds,
+                   cls_preds,
+                   input_metas,
+                   rescale=False):
+        """Generate bboxes from RPN head predictions.
+
+        Args:
+            points (torch.Tensor): Input points.
+            bbox_preds (torch.Tensor): Regression predictions from PointRCNN
+                head.
+            cls_preds (torch.Tensor): Class scores predictions from PointRCNN
+                head.
+            input_metas (list[dict]): Point cloud and image's meta info.
+            rescale (bool, optional): Whether to rescale bboxes.
+                Defaults to False.
+
+        Returns:
+            list[tuple[torch.Tensor]]: Bounding boxes, scores and labels.
+        """
+        sem_scores = cls_preds.sigmoid()
+        obj_scores = sem_scores.max(-1)[0]
+        object_class = sem_scores.argmax(dim=-1)
+
+        batch_size = sem_scores.shape[0]
+        results = list()
+        for b in range(batch_size):
+            bbox3d = self.bbox_coder.decode(bbox_preds[b], points[b, ..., :3],
+                                            object_class[b])
+            bbox_selected, score_selected, labels, cls_preds_selected = \
+                self.class_agnostic_nms(obj_scores[b], sem_scores[b], bbox3d)
+            bbox = input_metas[b]['box_type_3d'](
+                bbox_selected.clone(),
+                box_dim=bbox_selected.shape[-1],
+                with_yaw=True)
+            results.append((bbox, score_selected, labels, cls_preds_selected))
+        return results
+
+    def class_agnostic_nms(self, obj_scores, sem_scores, bbox):
+        """Class agnostic NMS on proposals.
+
+        Args:
+            obj_scores (torch.Tensor): Objectness scores of bounding boxes.
+            sem_scores (torch.Tensor): Semantic class scores of bounding
+                boxes.
+            bbox (torch.Tensor): Predicted bounding boxes.
+
+        Returns:
+            tuple[torch.Tensor]: Bounding boxes, scores, labels and class
+                scores after NMS.
+        """
+        nms_cfg = self.test_cfg.nms_cfg if not self.training \
+            else self.train_cfg.nms_cfg
+        if nms_cfg.use_rotate_nms:
+            nms_func = nms_gpu
+        else:
+            nms_func = nms_normal_gpu
+
+        if self.test_cfg.score_thr is not None:
+            score_thr = self.test_cfg.score_thr
+            keep = (obj_scores >= score_thr)
+            obj_scores = obj_scores[keep]
+            sem_scores = sem_scores[keep]
+            bbox = bbox[keep]
+
+        if obj_scores.shape[0] > 0:
+            topk = min(nms_cfg.nms_pre, obj_scores.shape[0])
+            obj_scores_nms, indices = torch.topk(obj_scores, k=topk)
+            bbox_for_nms = bbox[indices]
+            sem_scores_nms = sem_scores[indices]
+
+            keep = nms_func(bbox_for_nms[:, 0:7], obj_scores_nms,
+                            nms_cfg.iou_thr)
+            keep = keep[:nms_cfg.nms_post]
+
+            bbox_selected = bbox_for_nms[keep]
+            score_selected = obj_scores_nms[keep]
+            cls_preds = sem_scores_nms[keep]
+            labels = torch.argmax(cls_preds, -1)
+        else:
+            # guard the empty case, otherwise the variables below would be
+            # referenced before assignment
+            bbox_selected = bbox.new_zeros((0, bbox.shape[-1]))
+            score_selected = obj_scores.new_zeros(0)
+            cls_preds = sem_scores.new_zeros((0, sem_scores.shape[-1]))
+            labels = obj_scores.new_zeros(0, dtype=torch.long)
+
+        return bbox_selected, score_selected, labels, cls_preds
+
+    def _assign_targets_by_points_inside(self, bboxes_3d, points):
+        """Compute assignment by checking whether points are inside bboxes.
+
+        Args:
+            bboxes_3d (:obj:`BaseInstance3DBoxes`): Instance of bounding
+                boxes.
+            points (torch.Tensor): Points of a batch.
+
+        Returns:
+            tuple[torch.Tensor]: Flags indicating whether each point is
+                inside a bbox and the index of the box each point falls in.
+        """
+        # TODO: align points_in_boxes function in each box_structures
+        num_bbox = bboxes_3d.tensor.shape[0]
+        if isinstance(bboxes_3d, LiDARInstance3DBoxes):
+            assignment = bboxes_3d.points_in_boxes(points[:, 0:3]).long()
+            points_mask = assignment.new_zeros(
+                [assignment.shape[0], num_bbox + 1])
+            assignment[assignment == -1] = num_bbox
+            points_mask.scatter_(1, assignment.unsqueeze(1), 1)
+            points_mask = points_mask[:, :-1]
+            assignment[assignment == num_bbox] = num_bbox - 1
+        elif isinstance(bboxes_3d, DepthInstance3DBoxes):
+            points_mask = bboxes_3d.points_in_boxes(points)
+            assignment = points_mask.argmax(dim=-1)
+        else:
+            raise NotImplementedError('Unsupported bbox type!')
+
+        return points_mask, assignment
diff --git a/mmdet3d/models/dense_heads/ssd_3d_head.py b/mmdet3d/models/dense_heads/ssd_3d_head.py
index 0ce042fef1..6ab827d5c3 100644
--- a/mmdet3d/models/dense_heads/ssd_3d_head.py
+++ b/mmdet3d/models/dense_heads/ssd_3d_head.py
@@ -441,7 +441,7 @@ def get_targets_single(self,
                 negative_mask)
 
     def get_bboxes(self, points, bbox_preds, input_metas, rescale=False):
-        """Generate bboxes from sdd3d head predictions.
+        """Generate bboxes from 3DSSD head predictions.
 
         Args:
             points (torch.Tensor): Input points.
diff --git a/mmdet3d/models/detectors/__init__.py b/mmdet3d/models/detectors/__init__.py
index 4673617535..b53b19f2f2 100644
--- a/mmdet3d/models/detectors/__init__.py
+++ b/mmdet3d/models/detectors/__init__.py
@@ -10,6 +10,7 @@
 from .mvx_faster_rcnn import DynamicMVXFasterRCNN, MVXFasterRCNN
 from .mvx_two_stage import MVXTwoStageDetector
 from .parta2 import PartA2
+from .pointrcnn import PointRCNN
 from .single_stage_mono3d import SingleStageMono3DDetector
 from .smoke_mono3d import SMOKEMono3D
 from .ssd3dnet import SSD3DNet
@@ -20,5 +21,5 @@
     'Base3DDetector', 'VoxelNet', 'DynamicVoxelNet', 'MVXTwoStageDetector',
     'DynamicMVXFasterRCNN', 'MVXFasterRCNN', 'PartA2', 'VoteNet', 'H3DNet',
     'CenterPoint', 'SSD3DNet', 'ImVoteNet', 'SingleStageMono3DDetector',
-    'FCOSMono3D', 'ImVoxelNet', 'GroupFree3DNet', 'SMOKEMono3D'
+    'FCOSMono3D', 'ImVoxelNet', 'GroupFree3DNet', 'PointRCNN', 'SMOKEMono3D'
 ]
diff --git a/mmdet3d/models/detectors/pointrcnn.py b/mmdet3d/models/detectors/pointrcnn.py
new file mode 100644
index 0000000000..93b6b37974
--- /dev/null
+++ b/mmdet3d/models/detectors/pointrcnn.py
@@ -0,0 +1,144 @@
+import torch
+
+from mmdet.models import DETECTORS
+from .two_stage import TwoStage3DDetector
+
+
+@DETECTORS.register_module()
+class PointRCNN(TwoStage3DDetector):
+    r"""PointRCNN detector.
+
+    Please refer to the `PointRCNN <https://arxiv.org/abs/1812.04244>`_ paper.
+
+    Args:
+        backbone (dict): Config dict of detector's backbone.
+        neck (dict, optional): Config dict of neck. Defaults to None.
+        rpn_head (dict, optional): Config of RPN head. Defaults to None.
+        roi_head (dict, optional): Config of ROI head. Defaults to None.
+        train_cfg (dict, optional): Train configs. Defaults to None.
+        test_cfg (dict, optional): Test configs. Defaults to None.
+        pretrained (str, optional): Path of pretrained models.
+            Defaults to None.
+        init_cfg (dict, optional): Config of initialization. Defaults to None.
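+
+    Example:
+        A minimal construction sketch, assuming the config file shipped in
+        this patch and a working directory at the repository root:
+
+            >>> from mmcv import Config
+            >>> from mmdet3d.models import build_detector
+            >>> cfg = Config.fromfile(
+            ...     'configs/pointrcnn/pointrcnn_2x8_kitti-3d-3classes.py')
+            >>> model = build_detector(cfg.model)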
+ """ + + def __init__(self, + backbone, + neck=None, + rpn_head=None, + roi_head=None, + train_cfg=None, + test_cfg=None, + pretrained=None, + init_cfg=None): + super(PointRCNN, self).__init__( + backbone=backbone, + neck=neck, + rpn_head=rpn_head, + roi_head=roi_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + pretrained=pretrained, + init_cfg=init_cfg) + + def extract_feat(self, points): + """Directly extract features from the backbone+neck. + + Args: + points (torch.Tensor): Input points. + + Returns: + dict: Features from the backbone+neck + """ + x = self.backbone(points) + + if self.with_neck: + x = self.neck(x) + return x + + def forward_train(self, points, img_metas, gt_bboxes_3d, gt_labels_3d): + """Forward of training. + + Args: + points (list[torch.Tensor]): Points of each batch. + img_metas (list[dict]): Meta information of each sample. + gt_bboxes_3d (:obj:`BaseInstance3DBoxes`): gt bboxes of each batch. + gt_labels_3d (list[torch.Tensor]): gt class labels of each batch. + + Returns: + dict: Losses. + """ + losses = dict() + points_cat = torch.stack(points) + x = self.extract_feat(points_cat) + + # features for rcnn + backbone_feats = x['fp_features'].clone() + backbone_xyz = x['fp_xyz'].clone() + rcnn_feats = {'features': backbone_feats, 'points': backbone_xyz} + + bbox_preds, cls_preds = self.rpn_head(x) + + rpn_loss = self.rpn_head.loss( + bbox_preds=bbox_preds, + cls_preds=cls_preds, + points=points, + gt_bboxes_3d=gt_bboxes_3d, + gt_labels_3d=gt_labels_3d, + img_metas=img_metas) + losses.update(rpn_loss) + + bbox_list = self.rpn_head.get_bboxes(points_cat, bbox_preds, cls_preds, + img_metas) + proposal_list = [ + dict( + boxes_3d=bboxes, + scores_3d=scores, + labels_3d=labels, + cls_preds=preds_cls) + for bboxes, scores, labels, preds_cls in bbox_list + ] + rcnn_feats.update({'points_cls_preds': cls_preds}) + + roi_losses = self.roi_head.forward_train(rcnn_feats, img_metas, + proposal_list, gt_bboxes_3d, + gt_labels_3d) + losses.update(roi_losses) + + return losses + + def simple_test(self, points, img_metas, imgs=None, rescale=False): + """Forward of testing. + + Args: + points (list[torch.Tensor]): Points of each sample. + img_metas (list[dict]): Image metas. + rescale (bool, optional): Whether to rescale results. + Defaults to False. + + Returns: + list: Predicted 3d boxes. + """ + points_cat = torch.stack(points) + + x = self.extract_feat(points_cat) + # features for rcnn + backbone_feats = x['fp_features'].clone() + backbone_xyz = x['fp_xyz'].clone() + rcnn_feats = {'features': backbone_feats, 'points': backbone_xyz} + bbox_preds, cls_preds = self.rpn_head(x) + rcnn_feats.update({'points_cls_preds': cls_preds}) + + bbox_list = self.rpn_head.get_bboxes( + points_cat, bbox_preds, cls_preds, img_metas, rescale=rescale) + + proposal_list = [ + dict( + boxes_3d=bboxes, + scores_3d=scores, + labels_3d=labels, + cls_preds=preds_cls) + for bboxes, scores, labels, preds_cls in bbox_list + ] + bbox_results = self.roi_head.simple_test(rcnn_feats, img_metas, + proposal_list) + + return bbox_results diff --git a/mmdet3d/models/necks/pointnet2_fp_neck.py b/mmdet3d/models/necks/pointnet2_fp_neck.py index b8b6b46d1b..1ba2fda9cd 100644 --- a/mmdet3d/models/necks/pointnet2_fp_neck.py +++ b/mmdet3d/models/necks/pointnet2_fp_neck.py @@ -71,7 +71,7 @@ def forward(self, feat_dict): - fp_xyz (torch.Tensor): The coordinates of fp features. - fp_features (torch.Tensor): The features from the last - feature propogation layers. + feature propagation layers. 
""" sa_xyz, sa_features = self._extract_input(feat_dict) diff --git a/mmdet3d/models/roi_heads/__init__.py b/mmdet3d/models/roi_heads/__init__.py index 509c9ccb61..e607570d71 100644 --- a/mmdet3d/models/roi_heads/__init__.py +++ b/mmdet3d/models/roi_heads/__init__.py @@ -4,10 +4,11 @@ from .h3d_roi_head import H3DRoIHead from .mask_heads import PointwiseSemanticHead, PrimitiveHead from .part_aggregation_roi_head import PartAggregationROIHead +from .point_rcnn_roi_head import PointRCNNRoIHead from .roi_extractors import Single3DRoIAwareExtractor, SingleRoIExtractor __all__ = [ 'Base3DRoIHead', 'PartAggregationROIHead', 'PointwiseSemanticHead', 'Single3DRoIAwareExtractor', 'PartA2BboxHead', 'SingleRoIExtractor', - 'H3DRoIHead', 'PrimitiveHead' + 'H3DRoIHead', 'PrimitiveHead', 'PointRCNNRoIHead' ] diff --git a/mmdet3d/models/roi_heads/bbox_heads/__init__.py b/mmdet3d/models/roi_heads/bbox_heads/__init__.py index 6294f52f4c..fd7a6b04ae 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/__init__.py +++ b/mmdet3d/models/roi_heads/bbox_heads/__init__.py @@ -5,9 +5,10 @@ Shared4Conv1FCBBoxHead) from .h3d_bbox_head import H3DBboxHead from .parta2_bbox_head import PartA2BboxHead +from .point_rcnn_bbox_head import PointRCNNBboxHead __all__ = [ 'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead', 'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'PartA2BboxHead', - 'H3DBboxHead' + 'H3DBboxHead', 'PointRCNNBboxHead' ] diff --git a/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py b/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py index 6b10e0e730..e6e4e2b77b 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py +++ b/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py @@ -285,7 +285,7 @@ def forward(self, seg_feats, part_feats): def loss(self, cls_score, bbox_pred, rois, labels, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, bbox_weights): - """Coumputing losses. + """Computing losses. Args: cls_score (torch.Tensor): Scores of each roi. @@ -461,12 +461,13 @@ def _get_target_single(self, pos_bboxes, pos_gt_bboxes, ious, cfg): return (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, bbox_weights) - def get_corner_loss_lidar(self, pred_bbox3d, gt_bbox3d, delta=1): + def get_corner_loss_lidar(self, pred_bbox3d, gt_bbox3d, delta=1.0): """Calculate corner loss of given boxes. Args: pred_bbox3d (torch.FloatTensor): Predicted boxes in shape (N, 7). gt_bbox3d (torch.FloatTensor): Ground truth boxes in shape (N, 7). + delta (float, optional): huber loss threshold. Defaults to 1.0 Returns: torch.FloatTensor: Calculated corner loss in shape (N). 
@@ -489,8 +490,8 @@ def get_corner_loss_lidar(self, pred_bbox3d, gt_bbox3d, delta=1):
             torch.norm(pred_box_corners - gt_box_corners_flip, dim=2))  # (N, 8)
 
         # huber loss
-        abs_error = torch.abs(corner_dist)
-        quadratic = torch.clamp(abs_error, max=delta)
+        abs_error = corner_dist.abs()
+        quadratic = abs_error.clamp(max=delta)
         linear = (abs_error - quadratic)
         corner_loss = 0.5 * quadratic**2 + delta * linear
 
@@ -540,13 +541,13 @@ def get_bboxes(self,
             cur_box_prob = class_pred[batch_id]
             cur_rcnn_boxes3d = rcnn_boxes3d[roi_batch_id == batch_id]
 
-            selected = self.multi_class_nms(cur_box_prob, cur_rcnn_boxes3d,
-                                            cfg.score_thr, cfg.nms_thr,
-                                            img_metas[batch_id],
-                                            cfg.use_rotate_nms)
-            selected_bboxes = cur_rcnn_boxes3d[selected]
-            selected_label_preds = cur_class_labels[selected]
-            selected_scores = cur_cls_score[selected]
+            keep = self.multi_class_nms(cur_box_prob, cur_rcnn_boxes3d,
+                                        cfg.score_thr, cfg.nms_thr,
+                                        img_metas[batch_id],
+                                        cfg.use_rotate_nms)
+            selected_bboxes = cur_rcnn_boxes3d[keep]
+            selected_label_preds = cur_class_labels[keep]
+            selected_scores = cur_cls_score[keep]
 
             result_list.append(
                 (img_metas[batch_id]['box_type_3d'](selected_bboxes,
@@ -618,6 +619,6 @@ def multi_class_nms(self,
                               dtype=torch.int64,
                               device=box_preds.device))
 
-        selected = torch.cat(
+        keep = torch.cat(
             selected_list, dim=0) if len(selected_list) > 0 else []
-        return selected
+        return keep
diff --git a/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py b/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py
new file mode 100644
index 0000000000..2e1c2d5a49
--- /dev/null
+++ b/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py
@@ -0,0 +1,580 @@
+import numpy as np
+import torch
+from mmcv.cnn import ConvModule, normal_init
+from mmcv.cnn.bricks import build_conv_layer
+from mmcv.runner import BaseModule
+from torch import nn as nn
+
+from mmdet3d.core.bbox.structures import (LiDARInstance3DBoxes,
+                                          rotation_3d_in_axis, xywhr2xyxyr)
+from mmdet3d.models.builder import build_loss
+from mmdet3d.ops import build_sa_module
+from mmdet3d.ops.iou3d.iou3d_utils import nms_gpu, nms_normal_gpu
+from mmdet.core import build_bbox_coder, multi_apply
+from mmdet.models import HEADS
+
+
+@HEADS.register_module()
+class PointRCNNBboxHead(BaseModule):
+    """PointRCNN RoI Bbox head.
+
+    Args:
+        num_classes (int): The number of classes to predict.
+        in_channels (int): Input channels of point features.
+        mlp_channels (list[int]): The number of channels of the shared MLP.
+        pred_layer_cfg (dict, optional): Config of classification and
+            regression prediction layers. Defaults to None.
+        num_points (tuple, optional): The number of points which each SA
+            module samples. Defaults to (128, 32, -1).
+        radius (tuple, optional): Sampling radius of each SA module.
+            Defaults to (0.2, 0.4, 100).
+        num_samples (tuple, optional): The number of samples for ball query
+            in each SA module. Defaults to (64, 64, 64).
+        sa_channels (tuple, optional): Out channels of each mlp in SA module.
+            Defaults to ((128, 128, 128), (128, 128, 256), (256, 256, 512)).
+        bbox_coder (dict, optional): Config dict of box coders.
+            Defaults to dict(type='DeltaXYZWLHRBBoxCoder').
+        sa_cfg (dict, optional): Config of set abstraction module, which may
+            contain the following keys and values:
+
+            - pool_mod (str): Pool method ('max' or 'avg') for SA modules.
+            - use_xyz (bool): Whether to use xyz as a part of features.
+            - normalize_xyz (bool): Whether to normalize xyz with radii in
+              each SA module.
+ Defaults to dict(type='PointSAModule', pool_mod='max', + use_xyz=True). + conv_cfg (dict, optional): Config dict of convolutional layers. + Defaults to dict(type='Conv1d'). + norm_cfg (dict, optional): Config dict of normalization layers. + Defaults to dict(type='BN1d'). + act_cfg (dict, optional): Config dict of activation layers. + Defaults to dict(type='ReLU'). + bias (str, optional): Type of bias. Defaults to 'auto'. + loss_bbox (dict, optional): Config of regression loss function. + Defaults to dict(type='SmoothL1Loss', beta=1.0 / 9.0, + reduction='sum', loss_weight=1.0). + loss_cls (dict, optional): Config of classification loss function. + Defaults to dict(type='CrossEntropyLoss', use_sigmoid=True, + reduction='sum', loss_weight=1.0). + with_corner_loss (bool, optional): Whether using corner loss. + Defaults to True. + init_cfg (dict, optional): Config of initialization. Defaults to None. + """ + + def __init__( + self, + num_classes, + in_channels, + mlp_channels, + pred_layer_cfg=None, + num_points=(128, 32, -1), + radius=(0.2, 0.4, 100), + num_samples=(64, 64, 64), + sa_channels=((128, 128, 128), (128, 128, 256), (256, 256, 512)), + bbox_coder=dict(type='DeltaXYZWLHRBBoxCoder'), + sa_cfg=dict(type='PointSAModule', pool_mod='max', use_xyz=True), + conv_cfg=dict(type='Conv1d'), + norm_cfg=dict(type='BN1d'), + act_cfg=dict(type='ReLU'), + bias='auto', + loss_bbox=dict( + type='SmoothL1Loss', + beta=1.0 / 9.0, + reduction='sum', + loss_weight=1.0), + loss_cls=dict( + type='CrossEntropyLoss', + use_sigmoid=True, + reduction='sum', + loss_weight=1.0), + with_corner_loss=True, + init_cfg=None): + super(PointRCNNBboxHead, self).__init__(init_cfg=init_cfg) + self.num_classes = num_classes + self.num_sa = len(sa_channels) + self.with_corner_loss = with_corner_loss + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.bias = bias + + self.loss_bbox = build_loss(loss_bbox) + self.loss_cls = build_loss(loss_cls) + self.bbox_coder = build_bbox_coder(bbox_coder) + self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False) + + self.in_channels = in_channels + mlp_channels = [self.in_channels] + mlp_channels + shared_mlps = nn.Sequential() + for i in range(len(mlp_channels) - 1): + shared_mlps.add_module( + f'layer{i}', + ConvModule( + mlp_channels[i], + mlp_channels[i + 1], + kernel_size=(1, 1), + stride=(1, 1), + inplace=False, + conv_cfg=dict(type='Conv2d'))) + self.xyz_up_layer = nn.Sequential(*shared_mlps) + + c_out = mlp_channels[-1] + self.merge_down_layer = ConvModule( + c_out * 2, + c_out, + kernel_size=(1, 1), + stride=(1, 1), + inplace=False, + conv_cfg=dict(type='Conv2d')) + + pre_channels = c_out + + self.SA_modules = nn.ModuleList() + sa_in_channel = pre_channels + + for sa_index in range(self.num_sa): + cur_sa_mlps = list(sa_channels[sa_index]) + cur_sa_mlps = [sa_in_channel] + cur_sa_mlps + sa_out_channel = cur_sa_mlps[-1] + + cur_num_points = num_points[sa_index] + if cur_num_points <= 0: + cur_num_points = None + self.SA_modules.append( + build_sa_module( + num_point=cur_num_points, + radius=radius[sa_index], + num_sample=num_samples[sa_index], + mlp_channels=cur_sa_mlps, + cfg=sa_cfg)) + sa_in_channel = sa_out_channel + self.cls_convs = self._add_conv_branch( + pred_layer_cfg.in_channels, pred_layer_cfg.cls_conv_channels) + self.reg_convs = self._add_conv_branch( + pred_layer_cfg.in_channels, pred_layer_cfg.reg_conv_channels) + + prev_channel = pred_layer_cfg.cls_conv_channels[-1] + self.conv_cls = build_conv_layer( + self.conv_cfg, + 
+            in_channels=prev_channel,
+            out_channels=self.num_classes,
+            kernel_size=1)
+        prev_channel = pred_layer_cfg.reg_conv_channels[-1]
+        self.conv_reg = build_conv_layer(
+            self.conv_cfg,
+            in_channels=prev_channel,
+            out_channels=self.bbox_coder.code_size * self.num_classes,
+            kernel_size=1)
+
+        if init_cfg is None:
+            self.init_cfg = dict(type='Xavier', layer=['Conv2d', 'Conv1d'])
+
+    def _add_conv_branch(self, in_channels, conv_channels):
+        """Add shared or separable branch.
+
+        Args:
+            in_channels (int): Input feature channel.
+            conv_channels (tuple): Middle feature channels.
+        """
+        conv_spec = [in_channels] + list(conv_channels)
+        # add branch specific conv layers
+        conv_layers = nn.Sequential()
+        for i in range(len(conv_spec) - 1):
+            conv_layers.add_module(
+                f'layer{i}',
+                ConvModule(
+                    conv_spec[i],
+                    conv_spec[i + 1],
+                    kernel_size=1,
+                    padding=0,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg,
+                    bias=self.bias,
+                    inplace=True))
+        return conv_layers
+
+    def init_weights(self):
+        """Initialize weights of the head."""
+        super().init_weights()
+        for m in self.modules():
+            if isinstance(m, (nn.Conv2d, nn.Conv1d)):
+                if m.bias is not None:
+                    nn.init.constant_(m.bias, 0)
+        # pass the module itself so that mmcv initializes its weight
+        normal_init(self.conv_reg, mean=0, std=0.001)
+
+    def forward(self, feats):
+        """Forward pass.
+
+        Args:
+            feats (torch.Tensor): Features from RCNN modules.
+
+        Returns:
+            tuple[torch.Tensor]: Score of class and bbox predictions.
+        """
+        input_data = feats.clone().detach()
+        xyz_input = input_data[..., 0:self.in_channels].transpose(
+            1, 2).unsqueeze(dim=3).contiguous().clone().detach()
+        xyz_features = self.xyz_up_layer(xyz_input)
+        rpn_features = input_data[..., self.in_channels:].transpose(
+            1, 2).unsqueeze(dim=3)
+        merged_features = torch.cat((xyz_features, rpn_features), dim=1)
+        merged_features = self.merge_down_layer(merged_features)
+        l_xyz, l_features = [input_data[..., 0:3].contiguous()], \
+            [merged_features.squeeze(dim=3)]
+        for i in range(len(self.SA_modules)):
+            li_xyz, li_features, cur_indices = \
+                self.SA_modules[i](l_xyz[i], l_features[i])
+            l_xyz.append(li_xyz)
+            l_features.append(li_features)
+
+        shared_features = l_features[-1]
+        x_cls = shared_features
+        x_reg = shared_features
+        x_cls = self.cls_convs(x_cls)
+        rcnn_cls = self.conv_cls(x_cls)
+        x_reg = self.reg_convs(x_reg)
+        rcnn_reg = self.conv_reg(x_reg)
+        rcnn_cls = rcnn_cls.transpose(1, 2).contiguous().squeeze(dim=1)
+        rcnn_reg = rcnn_reg.transpose(1, 2).contiguous().squeeze(dim=1)
+        return (rcnn_cls, rcnn_reg)
+
+    def loss(self, cls_score, bbox_pred, rois, labels, bbox_targets,
+             pos_gt_bboxes, reg_mask, label_weights, bbox_weights):
+        """Computing losses.
+
+        Args:
+            cls_score (torch.Tensor): Scores of each RoI.
+            bbox_pred (torch.Tensor): Predictions of bboxes.
+            rois (torch.Tensor): RoI bboxes.
+            labels (torch.Tensor): Labels of class.
+            bbox_targets (torch.Tensor): Target of positive bboxes.
+            pos_gt_bboxes (torch.Tensor): Ground truths of positive bboxes.
+            reg_mask (torch.Tensor): Mask for positive bboxes.
+            label_weights (torch.Tensor): Weights of class loss.
+            bbox_weights (torch.Tensor): Weights of bbox loss.
+
+        Returns:
+            dict: Computed losses.
+
+                - loss_cls (torch.Tensor): Loss of classes.
+                - loss_bbox (torch.Tensor): Loss of bboxes.
+                - loss_corner (torch.Tensor): Loss of corners.
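+
+        Example:
+            The corner loss computed below penalizes the eight box corners
+            with a Huber-style loss; a minimal, runnable numeric sketch
+            (only two of the eight corner distances, for brevity):
+
+                >>> import torch
+                >>> delta = 1.0
+                >>> corner_dist = torch.tensor([[0.05, 2.0]])
+                >>> abs_error = corner_dist.abs()
+                >>> quadratic = abs_error.clamp(max=delta)
+                >>> linear = abs_error - quadratic
+                >>> loss = (0.5 * quadratic**2 + delta * linear).mean(dim=1)
+                >>> # loss is approximately tensor([0.7506])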
+ """ + losses = dict() + rcnn_batch_size = cls_score.shape[0] + # calculate class loss + cls_flat = cls_score.view(-1) + loss_cls = self.loss_cls(cls_flat, labels, label_weights) + losses['loss_cls'] = loss_cls + + # calculate regression loss + code_size = self.bbox_coder.code_size + pos_inds = (reg_mask > 0) + if pos_inds.any() == 0: + # fake a part loss + losses['loss_bbox'] = loss_cls.new_tensor(0) + if self.with_corner_loss: + losses['loss_corner'] = loss_cls.new_tensor(0) + else: + pos_bbox_pred = bbox_pred.view(rcnn_batch_size, + -1)[pos_inds].clone() + bbox_weights_flat = bbox_weights[pos_inds].view(-1, 1).repeat( + 1, pos_bbox_pred.shape[-1]) + loss_bbox = self.loss_bbox( + pos_bbox_pred.unsqueeze(dim=0), + bbox_targets.unsqueeze(dim=0).detach(), + bbox_weights_flat.unsqueeze(dim=0)) + losses['loss_bbox'] = loss_bbox + + if self.with_corner_loss: + rois = rois.detach() + pos_roi_boxes3d = rois[..., 1:].view(-1, code_size)[pos_inds] + pos_roi_boxes3d = pos_roi_boxes3d.view(-1, code_size) + batch_anchors = pos_roi_boxes3d.clone().detach() + pos_rois_rotation = pos_roi_boxes3d[..., 6].view(-1) + roi_xyz = pos_roi_boxes3d[..., 0:3].view(-1, 3) + batch_anchors[..., 0:3] = 0 + # decode boxes + pred_boxes3d = self.bbox_coder.decode( + batch_anchors, + pos_bbox_pred.view(-1, code_size)).view(-1, code_size) + + pred_boxes3d[..., 0:3] = rotation_3d_in_axis( + pred_boxes3d[..., 0:3].unsqueeze(1), (pos_rois_rotation), + axis=2).squeeze(1) + + pred_boxes3d[:, 0:3] += roi_xyz + + # calculate corner loss + loss_corner = self.get_corner_loss_lidar( + pred_boxes3d, pos_gt_bboxes) + + losses['loss_corner'] = loss_corner + + return losses + + def get_corner_loss_lidar(self, pred_bbox3d, gt_bbox3d, delta=1.0): + """Calculate corner loss of given boxes. + + Args: + pred_bbox3d (torch.FloatTensor): Predicted boxes in shape (N, 7). + gt_bbox3d (torch.FloatTensor): Ground truth boxes in shape (N, 7). + delta (float, optional): huber loss threshold. Defaults to 1.0 + + Returns: + torch.FloatTensor: Calculated corner loss in shape (N). + """ + assert pred_bbox3d.shape[0] == gt_bbox3d.shape[0] + + # This is a little bit hack here because we assume the box for + # PointRCNN is in LiDAR coordinates + + gt_boxes_structure = LiDARInstance3DBoxes(gt_bbox3d) + pred_box_corners = LiDARInstance3DBoxes(pred_bbox3d).corners + gt_box_corners = gt_boxes_structure.corners + + # This flip only changes the heading direction of GT boxes + gt_bbox3d_flip = gt_boxes_structure.clone() + gt_bbox3d_flip.tensor[:, 6] += np.pi + gt_box_corners_flip = gt_bbox3d_flip.corners + + corner_dist = torch.min( + torch.norm(pred_box_corners - gt_box_corners, dim=2), + torch.norm(pred_box_corners - gt_box_corners_flip, dim=2)) + # huber loss + abs_error = corner_dist.abs() + quadratic = abs_error.clamp(max=delta) + linear = (abs_error - quadratic) + corner_loss = 0.5 * quadratic**2 + delta * linear + return corner_loss.mean(dim=1) + + def get_targets(self, sampling_results, rcnn_train_cfg, concat=True): + """Generate targets. + + Args: + sampling_results (list[:obj:`SamplingResult`]): + Sampled results from rois. + rcnn_train_cfg (:obj:`ConfigDict`): Training config of rcnn. + concat (bool, optional): Whether to concatenate targets between + batches. Defaults to True. + + Returns: + tuple[torch.Tensor]: Targets of boxes and class prediction. 
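+
+        Example:
+            A runnable sketch of the IoU-interpolated classification label
+            produced per sample below, using the default thresholds from
+            ``train_cfg.rcnn`` (``cls_pos_thr=0.7``, ``cls_neg_thr=0.25``):
+
+                >>> import torch
+                >>> cls_pos_thr, cls_neg_thr = 0.7, 0.25
+                >>> ious = torch.tensor([0.80, 0.50, 0.10])
+                >>> label = (ious > cls_pos_thr).float()
+                >>> interval = (ious <= cls_pos_thr) & (ious >= cls_neg_thr)
+                >>> scale = cls_pos_thr - cls_neg_thr
+                >>> label[interval] = (ious[interval] - cls_neg_thr) / scale
+                >>> # label -> tensor([1.0000, 0.5556, 0.0000])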
+ """ + pos_bboxes_list = [res.pos_bboxes for res in sampling_results] + pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results] + iou_list = [res.iou for res in sampling_results] + targets = multi_apply( + self._get_target_single, + pos_bboxes_list, + pos_gt_bboxes_list, + iou_list, + cfg=rcnn_train_cfg) + (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, + bbox_weights) = targets + + if concat: + label = torch.cat(label, 0) + bbox_targets = torch.cat(bbox_targets, 0) + pos_gt_bboxes = torch.cat(pos_gt_bboxes, 0) + reg_mask = torch.cat(reg_mask, 0) + + label_weights = torch.cat(label_weights, 0) + label_weights /= torch.clamp(label_weights.sum(), min=1.0) + + bbox_weights = torch.cat(bbox_weights, 0) + bbox_weights /= torch.clamp(bbox_weights.sum(), min=1.0) + + return (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, + bbox_weights) + + def _get_target_single(self, pos_bboxes, pos_gt_bboxes, ious, cfg): + """Generate training targets for a single sample. + + Args: + pos_bboxes (torch.Tensor): Positive boxes with shape + (N, 7). + pos_gt_bboxes (torch.Tensor): Ground truth boxes with shape + (M, 7). + ious (torch.Tensor): IoU between `pos_bboxes` and `pos_gt_bboxes` + in shape (N, M). + cfg (dict): Training configs. + + Returns: + tuple[torch.Tensor]: Target for positive boxes. + (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, + bbox_weights) + """ + cls_pos_mask = ious > cfg.cls_pos_thr + cls_neg_mask = ious < cfg.cls_neg_thr + interval_mask = (cls_pos_mask == 0) & (cls_neg_mask == 0) + # iou regression target + label = (cls_pos_mask > 0).float() + label[interval_mask] = (ious[interval_mask] - cfg.cls_neg_thr) / \ + (cfg.cls_pos_thr - cfg.cls_neg_thr) + # label weights + label_weights = (label >= 0).float() + # box regression target + reg_mask = pos_bboxes.new_zeros(ious.size(0)).long() + reg_mask[0:pos_gt_bboxes.size(0)] = 1 + bbox_weights = (reg_mask > 0).float() + if reg_mask.bool().any(): + pos_gt_bboxes_ct = pos_gt_bboxes.clone().detach() + roi_center = pos_bboxes[..., 0:3] + roi_ry = pos_bboxes[..., 6] % (2 * np.pi) + + # canonical transformation + pos_gt_bboxes_ct[..., 0:3] -= roi_center + pos_gt_bboxes_ct[..., 6] -= roi_ry + pos_gt_bboxes_ct[..., 0:3] = rotation_3d_in_axis( + pos_gt_bboxes_ct[..., 0:3].unsqueeze(1), -(roi_ry), + axis=2).squeeze(1) + + # flip orientation if gt have opposite orientation + ry_label = pos_gt_bboxes_ct[..., 6] % (2 * np.pi) # 0 ~ 2pi + is_opposite = (ry_label > np.pi * 0.5) & (ry_label < np.pi * 1.5) + ry_label[is_opposite] = (ry_label[is_opposite] + np.pi) % ( + 2 * np.pi) # (0 ~ pi/2, 3pi/2 ~ 2pi) + flag = ry_label > np.pi + ry_label[flag] = ry_label[flag] - np.pi * 2 # (-pi/2, pi/2) + ry_label = torch.clamp(ry_label, min=-np.pi / 2, max=np.pi / 2) + pos_gt_bboxes_ct[..., 6] = ry_label + + rois_anchor = pos_bboxes.clone().detach() + rois_anchor[:, 0:3] = 0 + rois_anchor[:, 6] = 0 + bbox_targets = self.bbox_coder.encode(rois_anchor, + pos_gt_bboxes_ct) + else: + # no fg bbox + bbox_targets = pos_gt_bboxes.new_empty((0, 7)) + + return (label, bbox_targets, pos_gt_bboxes, reg_mask, label_weights, + bbox_weights) + + def get_bboxes(self, + rois, + cls_score, + bbox_pred, + class_labels, + img_metas, + cfg=None): + """Generate bboxes from bbox head predictions. + + Args: + rois (torch.Tensor): RoI bounding boxes. + cls_score (torch.Tensor): Scores of bounding boxes. 
+            bbox_pred (torch.Tensor): Bounding boxes predictions.
+            class_labels (torch.Tensor): Labels of classes.
+            img_metas (list[dict]): Point cloud and image's meta info.
+            cfg (:obj:`ConfigDict`, optional): Testing config.
+                Defaults to None.
+
+        Returns:
+            list[tuple]: Decoded bbox, scores and labels after nms.
+        """
+        roi_batch_id = rois[..., 0]
+        roi_boxes = rois[..., 1:]  # boxes without batch id
+        batch_size = int(roi_batch_id.max().item() + 1)
+
+        # decode boxes
+        roi_ry = roi_boxes[..., 6].view(-1)
+        roi_xyz = roi_boxes[..., 0:3].view(-1, 3)
+        local_roi_boxes = roi_boxes.clone().detach()
+        local_roi_boxes[..., 0:3] = 0
+        rcnn_boxes3d = self.bbox_coder.decode(local_roi_boxes, bbox_pred)
+        rcnn_boxes3d[..., 0:3] = rotation_3d_in_axis(
+            rcnn_boxes3d[..., 0:3].unsqueeze(1), roi_ry, axis=2).squeeze(1)
+        rcnn_boxes3d[:, 0:3] += roi_xyz
+
+        # post processing
+        result_list = []
+        for batch_id in range(batch_size):
+            cur_class_labels = class_labels[batch_id]
+            cur_cls_score = cls_score[roi_batch_id == batch_id].view(-1)
+
+            cur_box_prob = cur_cls_score.unsqueeze(1)
+            cur_rcnn_boxes3d = rcnn_boxes3d[roi_batch_id == batch_id]
+            keep = self.multi_class_nms(cur_box_prob, cur_rcnn_boxes3d,
+                                        cfg.score_thr, cfg.nms_thr,
+                                        img_metas[batch_id],
+                                        cfg.use_rotate_nms)
+            selected_bboxes = cur_rcnn_boxes3d[keep]
+            selected_label_preds = cur_class_labels[keep]
+            selected_scores = cur_cls_score[keep]
+
+            result_list.append(
+                (img_metas[batch_id]['box_type_3d'](selected_bboxes,
+                                                    self.bbox_coder.code_size),
+                 selected_scores, selected_label_preds))
+        return result_list
+
+    def multi_class_nms(self,
+                        box_probs,
+                        box_preds,
+                        score_thr,
+                        nms_thr,
+                        input_meta,
+                        use_rotate_nms=True):
+        """Multi-class NMS for box head.
+
+        Note:
+            This function has large overlap with the `box3d_multiclass_nms`
+            implemented in `mmdet3d.core.post_processing`. We are considering
+            merging these two functions in the future.
+
+        Args:
+            box_probs (torch.Tensor): Predicted boxes probabilities in
+                shape (N,).
+            box_preds (torch.Tensor): Predicted boxes in shape (N, 7+C).
+            score_thr (float): Threshold of scores.
+            nms_thr (float): Threshold for NMS.
+            input_meta (dict): Meta information of the current sample.
+            use_rotate_nms (bool, optional): Whether to use rotated nms.
+                Defaults to True.
+
+        Returns:
+            torch.Tensor: Selected indices.
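+
+        Example:
+            A minimal sketch of the per-class score filtering and the index
+            bookkeeping used below, which maps indices kept for one class
+            back into the full proposal set (values are illustrative):
+
+                >>> import torch
+                >>> box_probs = torch.tensor([[0.9], [0.3], [0.7]])
+                >>> k = 0  # class index
+                >>> class_scores_keep = box_probs[:, k] >= 0.5
+                >>> original_idxs = class_scores_keep.nonzero(
+                ...     as_tuple=False).view(-1)
+                >>> original_idxs  # indices into the full proposal set
+                tensor([0, 2])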
+ """ + if use_rotate_nms: + nms_func = nms_gpu + else: + nms_func = nms_normal_gpu + + assert box_probs.shape[ + 1] == self.num_classes, f'box_probs shape: {str(box_probs.shape)}' + selected_list = [] + selected_labels = [] + boxes_for_nms = xywhr2xyxyr(input_meta['box_type_3d']( + box_preds, self.bbox_coder.code_size).bev) + + score_thresh = score_thr if isinstance( + score_thr, list) else [score_thr for x in range(self.num_classes)] + nms_thresh = nms_thr if isinstance( + nms_thr, list) else [nms_thr for x in range(self.num_classes)] + for k in range(0, self.num_classes): + class_scores_keep = box_probs[:, k] >= score_thresh[k] + + if class_scores_keep.int().sum() > 0: + original_idxs = class_scores_keep.nonzero( + as_tuple=False).view(-1) + cur_boxes_for_nms = boxes_for_nms[class_scores_keep] + cur_rank_scores = box_probs[class_scores_keep, k] + + cur_selected = nms_func(cur_boxes_for_nms, cur_rank_scores, + nms_thresh[k]) + + if cur_selected.shape[0] == 0: + continue + selected_list.append(original_idxs[cur_selected]) + selected_labels.append( + torch.full([cur_selected.shape[0]], + k + 1, + dtype=torch.int64, + device=box_preds.device)) + + keep = torch.cat( + selected_list, dim=0) if len(selected_list) > 0 else [] + return keep diff --git a/mmdet3d/models/roi_heads/point_rcnn_roi_head.py b/mmdet3d/models/roi_heads/point_rcnn_roi_head.py new file mode 100644 index 0000000000..daaf4b22cf --- /dev/null +++ b/mmdet3d/models/roi_heads/point_rcnn_roi_head.py @@ -0,0 +1,286 @@ +import torch +from torch.nn import functional as F + +from mmdet3d.core import AssignResult +from mmdet3d.core.bbox import bbox3d2result, bbox3d2roi +from mmdet.core import build_assigner, build_sampler +from mmdet.models import HEADS +from ..builder import build_head, build_roi_extractor +from .base_3droi_head import Base3DRoIHead + + +@HEADS.register_module() +class PointRCNNRoIHead(Base3DRoIHead): + """RoI head for PointRCNN. + + Args: + bbox_head (dict): Config of bbox_head. + point_roi_extractor (dict): Config of RoI extractor. + train_cfg (dict): Train configs. + test_cfg (dict): Test configs. + depth_normalizer (float, optional): Normalize depth feature. + Defaults to 70.0. + init_cfg (dict, optional): Config of initialization. Defaults to None. + """ + + def __init__(self, + bbox_head, + point_roi_extractor, + train_cfg, + test_cfg, + depth_normalizer=70.0, + pretrained=None, + init_cfg=None): + super(PointRCNNRoIHead, self).__init__( + bbox_head=bbox_head, + train_cfg=train_cfg, + test_cfg=test_cfg, + pretrained=pretrained, + init_cfg=init_cfg) + self.depth_normalizer = depth_normalizer + + if point_roi_extractor is not None: + self.point_roi_extractor = build_roi_extractor(point_roi_extractor) + + self.init_assigner_sampler() + + def init_bbox_head(self, bbox_head): + """Initialize box head. + + Args: + bbox_head (dict): Config dict of RoI Head. 
+        """
+        self.bbox_head = build_head(bbox_head)
+
+    def init_mask_head(self):
+        """Initialize mask head."""
+        pass
+
+    def init_assigner_sampler(self):
+        """Initialize assigner and sampler."""
+        self.bbox_assigner = None
+        self.bbox_sampler = None
+        if self.train_cfg:
+            if isinstance(self.train_cfg.assigner, dict):
+                self.bbox_assigner = build_assigner(self.train_cfg.assigner)
+            elif isinstance(self.train_cfg.assigner, list):
+                self.bbox_assigner = [
+                    build_assigner(res) for res in self.train_cfg.assigner
+                ]
+            self.bbox_sampler = build_sampler(self.train_cfg.sampler)
+
+    def forward_train(self, feats_dict, input_metas, proposal_list,
+                      gt_bboxes_3d, gt_labels_3d):
+        """Training forward function of PointRCNNRoIHead.
+
+        Args:
+            feats_dict (dict): Contains features from the first stage.
+            input_metas (list[dict]): Meta info of each input.
+            proposal_list (list[dict]): Proposal information from rpn.
+                The dictionary should contain the following keys:
+
+                - boxes_3d (:obj:`BaseInstance3DBoxes`): Proposal bboxes
+                - labels_3d (torch.Tensor): Labels of proposals
+            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]):
+                GT bboxes of each sample. The bboxes are encapsulated
+                by 3D box structures.
+            gt_labels_3d (list[LongTensor]): GT labels of each sample.
+
+        Returns:
+            dict: Losses from RoI RCNN head.
+
+                - loss_bbox (torch.Tensor): Loss of bboxes.
+        """
+        features = feats_dict['features']
+        points = feats_dict['points']
+        point_cls_preds = feats_dict['points_cls_preds']
+        sem_scores = point_cls_preds.sigmoid()
+        point_scores = sem_scores.max(-1)[0]
+
+        sample_results = self._assign_and_sample(proposal_list, gt_bboxes_3d,
+                                                 gt_labels_3d)
+
+        # concat the depth, semantic features and backbone features
+        features = features.transpose(1, 2).contiguous()
+        point_depths = points.norm(dim=2) / self.depth_normalizer - 0.5
+        features_list = [
+            point_scores.unsqueeze(2),
+            point_depths.unsqueeze(2), features
+        ]
+        features = torch.cat(features_list, dim=2)
+
+        bbox_results = self._bbox_forward_train(features, points,
+                                                sample_results)
+        losses = dict()
+        losses.update(bbox_results['loss_bbox'])
+
+        return losses
+
+    def simple_test(self, feats_dict, img_metas, proposal_list, **kwargs):
+        """Simple testing forward function of PointRCNNRoIHead.
+
+        Note:
+            This function assumes that the batch size is 1.
+
+        Args:
+            feats_dict (dict): Contains features from the first stage.
+            img_metas (list[dict]): Meta info of each image.
+            proposal_list (list[dict]): Proposal information from rpn.
+
+        Returns:
+            dict: Bbox results of one frame.
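+
+        Example:
+            A runnable shape sketch of the depth/score feature
+            concatenation performed below (``B``, ``N`` and ``C`` are
+            illustrative; 70.0 mirrors the default ``depth_normalizer``):
+
+                >>> import torch
+                >>> B, N, C = 1, 4, 128
+                >>> features = torch.rand(B, N, C)
+                >>> points = torch.rand(B, N, 3)
+                >>> point_scores = torch.rand(B, N)
+                >>> point_depths = points.norm(dim=2) / 70.0 - 0.5
+                >>> torch.cat([point_scores.unsqueeze(2),
+                ...            point_depths.unsqueeze(2), features],
+                ...           dim=2).shape
+                torch.Size([1, 4, 130])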
+        """
+        rois = bbox3d2roi([res['boxes_3d'].tensor for res in proposal_list])
+        labels_3d = [res['labels_3d'] for res in proposal_list]
+
+        features = feats_dict['features']
+        points = feats_dict['points']
+        point_cls_preds = feats_dict['points_cls_preds']
+        sem_scores = point_cls_preds.sigmoid()
+        point_scores = sem_scores.max(-1)[0]
+
+        features = features.transpose(1, 2).contiguous()
+        point_depths = points.norm(dim=2) / self.depth_normalizer - 0.5
+        features_list = [
+            point_scores.unsqueeze(2),
+            point_depths.unsqueeze(2), features
+        ]
+
+        features = torch.cat(features_list, dim=2)
+        batch_size = features.shape[0]
+        bbox_results = self._bbox_forward(features, points, batch_size, rois)
+        object_score = bbox_results['cls_score'].sigmoid()
+        bbox_list = self.bbox_head.get_bboxes(
+            rois,
+            object_score,
+            bbox_results['bbox_pred'],
+            labels_3d,
+            img_metas,
+            cfg=self.test_cfg)
+
+        bbox_results = [
+            bbox3d2result(bboxes, scores, labels)
+            for bboxes, scores, labels in bbox_list
+        ]
+        return bbox_results
+
+    def _bbox_forward_train(self, features, points, sampling_results):
+        """Forward training function of roi_extractor and bbox_head.
+
+        Args:
+            features (torch.Tensor): Backbone features with depth and
+                semantic features.
+            points (torch.Tensor): Point cloud.
+            sampling_results (:obj:`SamplingResult`): Sampled results used
+                for training.
+
+        Returns:
+            dict: Forward results including losses and predictions.
+        """
+        rois = bbox3d2roi([res.bboxes for res in sampling_results])
+        batch_size = features.shape[0]
+        bbox_results = self._bbox_forward(features, points, batch_size, rois)
+        bbox_targets = self.bbox_head.get_targets(sampling_results,
+                                                  self.train_cfg)
+
+        loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
+                                        bbox_results['bbox_pred'], rois,
+                                        *bbox_targets)
+
+        bbox_results.update(loss_bbox=loss_bbox)
+        return bbox_results
+
+    def _bbox_forward(self, features, points, batch_size, rois):
+        """Forward function of roi_extractor and bbox_head used in both
+        training and testing.
+
+        Args:
+            features (torch.Tensor): Backbone features with depth and
+                semantic features.
+            points (torch.Tensor): Point cloud.
+            batch_size (int): Batch size.
+            rois (torch.Tensor): RoI boxes.
+
+        Returns:
+            dict: Contains predictions of bbox_head and
+                features of roi_extractor.
+        """
+        pooled_point_feats = self.point_roi_extractor(features, points,
+                                                      batch_size, rois)
+
+        cls_score, bbox_pred = self.bbox_head(pooled_point_feats)
+        bbox_results = dict(cls_score=cls_score, bbox_pred=bbox_pred)
+        return bbox_results
+
+    def _assign_and_sample(self, proposal_list, gt_bboxes_3d, gt_labels_3d):
+        """Assign and sample proposals for training.
+
+        Args:
+            proposal_list (list[dict]): Proposals produced by RPN.
+            gt_bboxes_3d (list[:obj:`BaseInstance3DBoxes`]): Ground truth
+                boxes.
+            gt_labels_3d (list[torch.Tensor]): Ground truth labels.
+
+        Returns:
+            list[:obj:`SamplingResult`]: Sampled results of each training
+                sample.
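+
+        Example:
+            A runnable sketch of the per-class gt index remapping used
+            below: class-local assigner results (-1 ignore, 0 background,
+            >0 one-based gt index within the class) are mapped back to
+            global one-based gt indices via a padded lookup table:
+
+                >>> import torch
+                >>> from torch.nn import functional as F
+                >>> # gts 2 and 5 (zero-based) belong to the current class
+                >>> gt_per_cls = torch.tensor(
+                ...     [False, False, True, False, False, True])
+                >>> pad = gt_per_cls.nonzero(as_tuple=False).view(-1) + 1
+                >>> pad = F.pad(pad, (1, 0), mode='constant', value=0)
+                >>> pad = F.pad(pad, (1, 0), mode='constant', value=-1)
+                >>> pad = pad + 1
+                >>> local_gt_inds = torch.tensor([-1, 0, 1, 2])
+                >>> pad[local_gt_inds + 1] - 1
+                tensor([-1,  0,  3,  6])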
+ """ + sampling_results = [] + # bbox assign + for batch_idx in range(len(proposal_list)): + cur_proposal_list = proposal_list[batch_idx] + cur_boxes = cur_proposal_list['boxes_3d'] + cur_labels_3d = cur_proposal_list['labels_3d'] + cur_gt_bboxes = gt_bboxes_3d[batch_idx].to(cur_boxes.device) + cur_gt_labels = gt_labels_3d[batch_idx] + batch_num_gts = 0 + # 0 is bg + batch_gt_indis = cur_gt_labels.new_full((len(cur_boxes), ), 0) + batch_max_overlaps = cur_boxes.tensor.new_zeros(len(cur_boxes)) + # -1 is bg + batch_gt_labels = cur_gt_labels.new_full((len(cur_boxes), ), -1) + + # each class may have its own assigner + if isinstance(self.bbox_assigner, list): + for i, assigner in enumerate(self.bbox_assigner): + gt_per_cls = (cur_gt_labels == i) + pred_per_cls = (cur_labels_3d == i) + cur_assign_res = assigner.assign( + cur_boxes.tensor[pred_per_cls], + cur_gt_bboxes.tensor[gt_per_cls], + gt_labels=cur_gt_labels[gt_per_cls]) + # gather assign_results in different class into one result + batch_num_gts += cur_assign_res.num_gts + # gt inds (1-based) + gt_inds_arange_pad = gt_per_cls.nonzero( + as_tuple=False).view(-1) + 1 + # pad 0 for indice unassigned + gt_inds_arange_pad = F.pad( + gt_inds_arange_pad, (1, 0), mode='constant', value=0) + # pad -1 for indice ignore + gt_inds_arange_pad = F.pad( + gt_inds_arange_pad, (1, 0), mode='constant', value=-1) + # convert to 0~gt_num+2 for indices + gt_inds_arange_pad += 1 + # now 0 is bg, >1 is fg in batch_gt_indis + batch_gt_indis[pred_per_cls] = gt_inds_arange_pad[ + cur_assign_res.gt_inds + 1] - 1 + batch_max_overlaps[ + pred_per_cls] = cur_assign_res.max_overlaps + batch_gt_labels[pred_per_cls] = cur_assign_res.labels + + assign_result = AssignResult(batch_num_gts, batch_gt_indis, + batch_max_overlaps, + batch_gt_labels) + else: # for single class + assign_result = self.bbox_assigner.assign( + cur_boxes.tensor, + cur_gt_bboxes.tensor, + gt_labels=cur_gt_labels) + + # sample boxes + sampling_result = self.bbox_sampler.sample(assign_result, + cur_boxes.tensor, + cur_gt_bboxes.tensor, + cur_gt_labels) + sampling_results.append(sampling_result) + return sampling_results diff --git a/mmdet3d/ops/__init__.py b/mmdet3d/ops/__init__.py index 1dafa428ac..1530f963b3 100644 --- a/mmdet3d/ops/__init__.py +++ b/mmdet3d/ops/__init__.py @@ -20,6 +20,7 @@ build_sa_module) from .roiaware_pool3d import (RoIAwarePool3d, points_in_boxes_all, points_in_boxes_cpu, points_in_boxes_part) +from .roipoint_pool3d import RoIPointPool3d from .sparse_block import (SparseBasicBlock, SparseBottleneck, make_sparse_convmodule) from .voxel import DynamicScatter, Voxelization, dynamic_scatter, voxelization @@ -39,5 +40,5 @@ 'get_compiler_version', 'assign_score_withk', 'get_compiling_cuda_version', 'Points_Sampler', 'build_sa_module', 'PAConv', 'PAConvCUDA', 'PAConvSAModuleMSG', 'PAConvSAModule', 'PAConvCUDASAModule', - 'PAConvCUDASAModuleMSG' + 'PAConvCUDASAModuleMSG', 'RoIPointPool3d' ] diff --git a/tests/test_models/test_detectors.py b/tests/test_models/test_detectors.py index cb788252af..1aad0fb27d 100644 --- a/tests/test_models/test_detectors.py +++ b/tests/test_models/test_detectors.py @@ -472,6 +472,35 @@ def test_imvoxelnet(): assert labels_3d.shape[0] >= 0 +def test_pointrcnn(): + if not torch.cuda.is_available(): + pytest.skip('test requires GPU and torch+cuda') + pointrcnn_cfg = _get_detector_cfg( + 'pointrcnn/pointrcnn_2x8_kitti-3d-3classes.py') + self = build_detector(pointrcnn_cfg).cuda() + points_0 = torch.rand([1000, 4], device='cuda') + points_1 = torch.rand([1000, 
4], device='cuda') + points = [points_0, points_1] + + img_meta_0 = dict(box_type_3d=LiDARInstance3DBoxes) + img_meta_1 = dict(box_type_3d=LiDARInstance3DBoxes) + img_metas = [img_meta_0, img_meta_1] + gt_bbox_0 = LiDARInstance3DBoxes(torch.rand([10, 7], device='cuda')) + gt_bbox_1 = LiDARInstance3DBoxes(torch.rand([10, 7], device='cuda')) + gt_bboxes = [gt_bbox_0, gt_bbox_1] + gt_labels_0 = torch.randint(0, 3, [10], device='cuda') + gt_labels_1 = torch.randint(0, 3, [10], device='cuda') + gt_labels = [gt_labels_0, gt_labels_1] + + # test_forward_train + losses = self.forward_train(points, img_metas, gt_bboxes, gt_labels) + assert losses['bbox_loss'] >= 0 + assert losses['semantic_loss'] >= 0 + assert losses['loss_cls'] >= 0 + assert losses['loss_bbox'] >= 0 + assert losses['loss_corner'] >= 0 + + def test_smoke(): if not torch.cuda.is_available(): pytest.skip('test requires GPU and torch+cuda') diff --git a/tests/test_models/test_heads/test_heads.py b/tests/test_models/test_heads/test_heads.py index fb6e64d87f..7aaec8a5ec 100644 --- a/tests/test_models/test_heads/test_heads.py +++ b/tests/test_models/test_heads/test_heads.py @@ -1,5 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy +import mmcv import numpy as np import pytest import random @@ -116,6 +117,23 @@ def _get_pts_bbox_head_cfg(fname): return pts_bbox_head +def _get_pointrcnn_rpn_head_cfg(fname): + """Grab configs necessary to create a rpn_head. + + These are deep copied to allow for safe modification of parameters without + influencing other tests. + """ + config = _get_config_module(fname) + model = copy.deepcopy(config.model) + train_cfg = mmcv.Config(copy.deepcopy(config.model.train_cfg)) + test_cfg = mmcv.Config(copy.deepcopy(config.model.test_cfg)) + + rpn_head = model.rpn_head + rpn_head.update(train_cfg=train_cfg.rpn) + rpn_head.update(test_cfg=test_cfg.rpn) + return rpn_head, train_cfg.rpn.rpn_proposal + + def _get_vote_head_cfg(fname): """Grab configs necessary to create a vote_head. 
@@ -147,6 +165,14 @@ def _get_parta2_bbox_head_cfg(fname):
     return vote_head
 
 
+def _get_pointrcnn_bbox_head_cfg(fname):
+    config = _get_config_module(fname)
+    model = copy.deepcopy(config.model)
+
+    bbox_head = model.roi_head.bbox_head
+    return bbox_head
+
+
 def test_anchor3d_head_loss():
     if not torch.cuda.is_available():
         pytest.skip('test requires GPU and torch+cuda')
@@ -263,6 +289,39 @@ def test_parta2_rpnhead_getboxes():
     assert result_list[0]['boxes_3d'].tensor.shape == torch.Size([512, 7])
 
 
+def test_pointrcnn_rpnhead_getboxes():
+    if not torch.cuda.is_available():
+        pytest.skip('test requires GPU and torch+cuda')
+    rpn_head_cfg, proposal_cfg = _get_pointrcnn_rpn_head_cfg(
+        './pointrcnn/pointrcnn_2x8_kitti-3d-3classes.py')
+    self = build_head(rpn_head_cfg)
+    self.cuda()
+
+    fp_features = torch.rand([2, 128, 1024], dtype=torch.float32).cuda()
+    feats = {'fp_features': fp_features}
+    # fake input_metas
+    input_metas = [{
+        'sample_idx': 1234,
+        'box_type_3d': LiDARInstance3DBoxes,
+        'box_mode_3d': Box3DMode.LIDAR
+    }, {
+        'sample_idx': 2345,
+        'box_type_3d': LiDARInstance3DBoxes,
+        'box_mode_3d': Box3DMode.LIDAR
+    }]
+    (bbox_preds, cls_preds) = self.forward(feats)
+    assert bbox_preds.shape == (2, 1024, 8)
+    assert cls_preds.shape == (2, 1024, 3)
+    points = torch.rand([2, 1024, 3], dtype=torch.float32).cuda()
+    result_list = self.get_bboxes(points, bbox_preds, cls_preds, input_metas)
+    max_num = proposal_cfg.max_num
+    bbox, score_selected, labels, cls_preds_selected = result_list[0]
+    assert bbox.tensor.shape == (max_num, 7)
+    assert score_selected.shape == (max_num, )
+    assert labels.shape == (max_num, )
+    assert cls_preds_selected.shape == (max_num, 3)
+
+
 def test_vote_head():
     if not torch.cuda.is_available():
         pytest.skip('test requires GPU and torch+cuda')
@@ -466,6 +525,18 @@ def test_parta2_bbox_head():
     assert bbox_pred.shape == (256, 7)
 
 
+def test_pointrcnn_bbox_head():
+    if not torch.cuda.is_available():
+        pytest.skip('test requires GPU and torch+cuda')
+    pointrcnn_bbox_head_cfg = _get_pointrcnn_bbox_head_cfg(
+        './pointrcnn/pointrcnn_2x8_kitti-3d-3classes.py')
+    self = build_head(pointrcnn_bbox_head_cfg).cuda()
+    feats = torch.rand([100, 512, 133]).cuda()
+    rcnn_cls, rcnn_reg = self.forward(feats)
+    assert rcnn_cls.shape == (100, 1)
+    assert rcnn_reg.shape == (100, 7)
+
+
 def test_part_aggregation_ROI_head():
     if not torch.cuda.is_available():
         pytest.skip('test requires GPU and torch+cuda')
@@ -540,6 +611,50 @@ def test_part_aggregation_ROI_head():
     assert labels_3d.shape == (12, )
 
 
+def test_pointrcnn_roi_head():
+    if not torch.cuda.is_available():
+        pytest.skip('test requires GPU and torch+cuda')
+
+    roi_head_cfg = _get_roi_head_cfg(
+        './pointrcnn/pointrcnn_2x8_kitti-3d-3classes.py')
+
+    self = build_head(roi_head_cfg).cuda()
+
+    features = torch.rand([3, 128, 16384]).cuda()
+    points = torch.rand([3, 16384, 3]).cuda()
+    points_cls_preds = torch.rand([3, 16384, 3]).cuda()
+    rcnn_feats = {
+        'features': features,
+        'points': points,
+        'points_cls_preds': points_cls_preds
+    }
+    boxes_3d = LiDARInstance3DBoxes(torch.rand(50, 7).cuda())
+    labels_3d = torch.randint(low=0, high=2, size=[50]).cuda()
+    proposal = {'boxes_3d': boxes_3d, 'labels_3d': labels_3d}
+    proposal_list = [proposal for i in range(3)]
+    gt_bboxes_3d = [
+        LiDARInstance3DBoxes(torch.rand([5, 7], device='cuda'))
+        for i in range(3)
+    ]
+    gt_labels_3d = [torch.randint(0, 2, [5], device='cuda') for i in range(3)]
+    box_type_3d = LiDARInstance3DBoxes
+    img_metas = [dict(box_type_3d=box_type_3d) for i in range(3)]
+
+    losses = self.forward_train(rcnn_feats, img_metas, proposal_list,
+                                 gt_bboxes_3d, gt_labels_3d)
+    assert losses['loss_cls'] >= 0
+    assert losses['loss_bbox'] >= 0
+    assert losses['loss_corner'] >= 0
+
+    bbox_results = self.simple_test(rcnn_feats, img_metas, proposal_list)
+    boxes_3d = bbox_results[0]['boxes_3d']
+    scores_3d = bbox_results[0]['scores_3d']
+    labels_3d = bbox_results[0]['labels_3d']
+    assert boxes_3d.tensor.shape[1] == 7
+    assert boxes_3d.tensor.shape[0] == scores_3d.shape[0]
+    assert scores_3d.shape[0] == labels_3d.shape[0]
+
+
 def test_free_anchor_3D_head():
     if not torch.cuda.is_available():
         pytest.skip('test requires GPU and torch+cuda')
@@ -700,7 +815,7 @@ def test_h3d_head():
     h3d_head_cfg.bbox_head.num_proposal = num_proposal
     self = build_head(h3d_head_cfg).cuda()
 
-    # prepare roi outputs
+    # prepare RoI outputs
     fp_xyz = [torch.rand([1, num_point, 3], dtype=torch.float32).cuda()]
     hd_features = torch.rand([1, 256, num_point], dtype=torch.float32).cuda()
     fp_indices = [torch.randint(0, 128, [1, num_point]).cuda()]
diff --git a/tests/test_runtime/test_config.py b/tests/test_runtime/test_config.py
index a1950e283b..650b46df65 100644
--- a/tests/test_runtime/test_config.py
+++ b/tests/test_runtime/test_config.py
@@ -61,6 +61,8 @@ def test_config_build_model():
             check_parta2_roi_head(head_config, detector.roi_head)
         elif head_config.type == 'H3DRoIHead':
             check_h3d_roi_head(head_config, detector.roi_head)
+        elif head_config.type == 'PointRCNNRoIHead':
+            check_pointrcnn_roi_head(head_config, detector.roi_head)
         else:
             _check_roi_head(head_config, detector.roi_head)
     # else:
@@ -273,3 +275,28 @@ def _check_h3d_bbox_head(bbox_cfg, bbox_head):
         12 == bbox_head.line_center_matcher.num_point[0]
     assert bbox_cfg.suface_matching_cfg.mlp_channels[-1] * \
         18 == bbox_head.bbox_pred[0].in_channels
+
+
+def check_pointrcnn_roi_head(config, head):
+    assert config['type'] == head.__class__.__name__
+
+    # check point_roi_extractor
+    point_roi_cfg = config.point_roi_extractor
+    point_roi_extractor = head.point_roi_extractor
+    _check_pointrcnn_roi_extractor(point_roi_cfg, point_roi_extractor)
+    # check the PointRCNN bbox head
+    bbox_cfg = config.bbox_head
+    bbox_head = head.bbox_head
+    _check_pointrcnn_bbox_head(bbox_cfg, bbox_head)
+
+
+def _check_pointrcnn_roi_extractor(config, roi_extractor):
+    assert config['type'] == roi_extractor.__class__.__name__
+    assert config.roi_layer.num_sampled_points == \
+        roi_extractor.roi_layer.num_sampled_points
+
+
+def _check_pointrcnn_bbox_head(bbox_cfg, bbox_head):
+    assert bbox_cfg['type'] == bbox_head.__class__.__name__
+    assert bbox_cfg.num_classes == bbox_head.num_classes
+    assert bbox_cfg.with_corner_loss == bbox_head.with_corner_loss

From 9b475a84c4de8978ffd66fbc571ac407e6f1ae1a Mon Sep 17 00:00:00 2001
From: Wenhao Wu
Date: Wed, 1 Dec 2021 21:49:58 +0800
Subject: [PATCH 2/7] rename config & model

---
 .../point_rcnn_2x8_kitti-3d-3classes.py}                 | 0
 mmdet3d/models/detectors/{pointrcnn.py => point_rcnn.py} | 0
 2 files changed, 0 insertions(+), 0 deletions(-)
 rename configs/{pointrcnn/pointrcnn_2x8_kitti-3d-3classes.py => point_rcnn/point_rcnn_2x8_kitti-3d-3classes.py} (100%)
 rename mmdet3d/models/detectors/{pointrcnn.py => point_rcnn.py} (100%)

diff --git a/configs/pointrcnn/pointrcnn_2x8_kitti-3d-3classes.py b/configs/point_rcnn/point_rcnn_2x8_kitti-3d-3classes.py
similarity index 100%
rename from configs/pointrcnn/pointrcnn_2x8_kitti-3d-3classes.py
rename to configs/point_rcnn/point_rcnn_2x8_kitti-3d-3classes.py
diff --git a/mmdet3d/models/detectors/pointrcnn.py
b/mmdet3d/models/detectors/point_rcnn.py similarity index 100% rename from mmdet3d/models/detectors/pointrcnn.py rename to mmdet3d/models/detectors/point_rcnn.py From abc5f7197ccdf829b9502899459ee8979b8d7a0e Mon Sep 17 00:00:00 2001 From: Wenhao Wu Date: Thu, 2 Dec 2021 00:18:36 +0800 Subject: [PATCH 3/7] fix unittest --- mmdet3d/models/detectors/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mmdet3d/models/detectors/__init__.py b/mmdet3d/models/detectors/__init__.py index b53b19f2f2..894d7f33cd 100644 --- a/mmdet3d/models/detectors/__init__.py +++ b/mmdet3d/models/detectors/__init__.py @@ -10,7 +10,7 @@ from .mvx_faster_rcnn import DynamicMVXFasterRCNN, MVXFasterRCNN from .mvx_two_stage import MVXTwoStageDetector from .parta2 import PartA2 -from .pointrcnn import PointRCNN +from .point_rcnn import PointRCNN from .single_stage_mono3d import SingleStageMono3DDetector from .smoke_mono3d import SMOKEMono3D from .ssd3dnet import SSD3DNet From 2e5f90d1c3302d0ddae30754434e0695ce0fa76e Mon Sep 17 00:00:00 2001 From: Wenhao Wu Date: Thu, 2 Dec 2021 15:57:35 +0800 Subject: [PATCH 4/7] resolve comments & add docstring for class_agnostic_nms --- mmdet3d/models/dense_heads/point_rpn_head.py | 30 +++++++++++++------- mmdet3d/models/dense_heads/ssd_3d_head.py | 10 +++---- mmdet3d/models/detectors/point_rcnn.py | 5 +++- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/mmdet3d/models/dense_heads/point_rpn_head.py b/mmdet3d/models/dense_heads/point_rpn_head.py index ac2548105b..203255e613 100644 --- a/mmdet3d/models/dense_heads/point_rpn_head.py +++ b/mmdet3d/models/dense_heads/point_rpn_head.py @@ -102,14 +102,14 @@ def _get_reg_out_channels(self): def forward(self, feat_dict): point_features = feat_dict['fp_features'] point_features = point_features.permute(0, 2, 1).contiguous() - bs = point_features.shape[0] - x_cls = point_features.view(-1, point_features.shape[-1]) - x_reg = point_features.view(-1, point_features.shape[-1]) - - point_cls_preds = self.cls_layers(x_cls).reshape( - bs, -1, self._get_cls_out_channels()) - point_box_preds = self.reg_layers(x_reg).reshape( - bs, -1, self._get_reg_out_channels()) + batch_size = point_features.shape[0] + feat_cls = point_features.view(-1, point_features.shape[-1]) + feat_reg = point_features.view(-1, point_features.shape[-1]) + + point_cls_preds = self.cls_layers(feat_cls).reshape( + batch_size, -1, self._get_cls_out_channels()) + point_box_preds = self.reg_layers(feat_reg).reshape( + batch_size, -1, self._get_reg_out_channels()) return (point_box_preds, point_cls_preds) @force_fp32(apply_to=('bbox_preds')) @@ -271,8 +271,18 @@ def get_bboxes(self, return results def class_agnostic_nms(self, obj_scores, sem_scores, bbox): + """Class agnostic nms. + + Args: + obj_scores (torch.Tensor): Objectness score of bounding boxes. + sem_scores (torch.Tensor): Semantic class score of bounding boxes. + bbox (torch.Tensor): Predicted bounding boxes. + + Returns: + tuple[torch.Tensor]: Bounding boxes, scores and labels. + """ nms_cfg = self.test_cfg.nms_cfg if not self.training \ - else self.train_cfg.nms_cfg + else self.train_cfg.nms_cfg if nms_cfg.use_rotate_nms: nms_func = nms_gpu else: @@ -306,7 +316,7 @@ def _assign_targets_by_points_inside(self, bboxes_3d, points): """Compute assignment by checking whether point is inside bbox. Args: - bboxes_3d (BaseInstance3DBoxes): Instance of bounding boxes. + bboxes_3d (:obj:`BaseInstance3DBoxes`): Instance of bounding boxes. points (torch.Tensor): Points of a batch. 
Returns: diff --git a/mmdet3d/models/dense_heads/ssd_3d_head.py b/mmdet3d/models/dense_heads/ssd_3d_head.py index 6ab827d5c3..85c60a7e1d 100644 --- a/mmdet3d/models/dense_heads/ssd_3d_head.py +++ b/mmdet3d/models/dense_heads/ssd_3d_head.py @@ -479,7 +479,7 @@ def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points, Args: obj_scores (torch.Tensor): Objectness score of bounding boxes. - sem_scores (torch.Tensor): semantic class score of bounding boxes. + sem_scores (torch.Tensor): Semantic class score of bounding boxes. bbox (torch.Tensor): Predicted bounding boxes. points (torch.Tensor): Input points. input_meta (dict): Point cloud and image's meta info. @@ -505,20 +505,20 @@ def multiclass_nms_single(self, obj_scores, sem_scores, bbox, points, minmax_box3d[:, 3:] = torch.max(corner3d, dim=1)[0] bbox_classes = torch.argmax(sem_scores, -1) - nms_selected = batched_nms( + nms_keep = batched_nms( minmax_box3d[nonempty_box_mask][:, [0, 1, 3, 4]], obj_scores[nonempty_box_mask], bbox_classes[nonempty_box_mask], self.test_cfg.nms_cfg)[1] - if nms_selected.shape[0] > self.test_cfg.max_output_num: - nms_selected = nms_selected[:self.test_cfg.max_output_num] + if nms_keep.shape[0] > self.test_cfg.max_output_num: + nms_keep = nms_keep[:self.test_cfg.max_output_num] # filter empty boxes and boxes with low score scores_mask = (obj_scores >= self.test_cfg.score_thr) nonempty_box_inds = torch.nonzero( nonempty_box_mask, as_tuple=False).flatten() nonempty_mask = torch.zeros_like(bbox_classes).scatter( - 0, nonempty_box_inds[nms_selected], 1) + 0, nonempty_box_inds[nms_keep], 1) selected = (nonempty_mask.bool() & scores_mask.bool()) if self.test_cfg.per_class_proposal: diff --git a/mmdet3d/models/detectors/point_rcnn.py b/mmdet3d/models/detectors/point_rcnn.py index 93b6b37974..6f2238a402 100644 --- a/mmdet3d/models/detectors/point_rcnn.py +++ b/mmdet3d/models/detectors/point_rcnn.py @@ -8,7 +8,7 @@ class PointRCNN(TwoStage3DDetector): r"""PointRCNN detector. - Please refer to the `PointRCNN `_ + Please refer to the `PointRCNN https://arxiv.org/abs/1812.04244`_ Args: backbone (dict): Config dict of detector's backbone. @@ -17,6 +17,7 @@ class PointRCNN(TwoStage3DDetector): roi_head (dict, optional): Config of ROI head. Defaults to None. train_cfg (dict, optional): Train configs. Defaults to None. test_cfg (dict, optional): Test configs. Defaults to None. + pretrained (str, optional): Model pretrained path. Defaults to None. init_cfg (dict, optional): Config of initialization. Defaults to None. """ @@ -111,6 +112,8 @@ def simple_test(self, points, img_metas, imgs=None, rescale=False): Args: points (list[torch.Tensor]): Points of each sample. img_metas (list[dict]): Image metas. + imgs (list[torch.Tensor], optional): Images of each sample. + Defaults to None. rescale (bool, optional): Whether to rescale results. Defaults to False. 
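For orientation before the next patch: the contract that the new `class_agnostic_nms` docstring pins down (objectness scores, semantic scores and boxes in; kept boxes, scores and labels out) can be sketched in a few lines. The sketch below is an illustration under stated assumptions, not the mmdet3d implementation: it suppresses greedily on axis-aligned BEV IoU, whereas the real head dispatches to the rotated `nms_gpu`/`nms_normal_gpu` ops, and the function name, the `[x1, y1, x2, y2]` box layout and the `iou_thr` default are invented for the example.

import torch


def class_agnostic_nms_sketch(obj_scores, sem_scores, bboxes_bev,
                              iou_thr=0.85):
    # Greedy NMS over axis-aligned BEV boxes [x1, y1, x2, y2], shared
    # across all classes; labels come from the semantic-score argmax.
    # NOTE: illustrative stand-in only, not the mmdet3d rotated-NMS path.
    order = obj_scores.argsort(descending=True)
    keep = []
    while order.numel() > 0:
        cur = order[0]
        keep.append(cur.item())
        rest = order[1:]
        if rest.numel() == 0:
            break
        # intersection of the current box with every remaining candidate
        lt = torch.maximum(bboxes_bev[cur, :2], bboxes_bev[rest, :2])
        rb = torch.minimum(bboxes_bev[cur, 2:], bboxes_bev[rest, 2:])
        inter = (rb - lt).clamp(min=0).prod(dim=1)
        area_cur = (bboxes_bev[cur, 2:] - bboxes_bev[cur, :2]).prod()
        area_rest = (bboxes_bev[rest, 2:] - bboxes_bev[rest, :2]).prod(dim=1)
        iou = inter / (area_cur + area_rest - inter)
        # drop candidates overlapping the kept box, regardless of class
        order = rest[iou <= iou_thr]
    keep = torch.as_tensor(keep, dtype=torch.long)
    return bboxes_bev[keep], obj_scores[keep], sem_scores[keep].argmax(-1)


# toy usage: 6 random boxes, 3 classes
boxes = torch.rand(6, 2).repeat(1, 2)
boxes[:, 2:] += 0.5  # ensure x2 > x1 and y2 > y1
kept_boxes, kept_scores, kept_labels = class_agnostic_nms_sketch(
    torch.rand(6), torch.rand(6, 3), boxes)

The property this sketch shares with the real head is that suppression runs once across all classes, so a high-scoring box of one class can suppress an overlapping box of another; per-box class labels are only attached afterwards from the semantic argmax.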
From 09134a2d241ba0182d3c0c61d0f4261d2825a653 Mon Sep 17 00:00:00 2001 From: Wenhao Wu Date: Wed, 8 Dec 2021 15:01:42 +0800 Subject: [PATCH 5/7] refine loss calculation & remove find_unused_parameters --- configs/_base_/models/point_rcnn.py | 2 - .../bbox_heads/point_rcnn_bbox_head.py | 74 +++++++++---------- 2 files changed, 35 insertions(+), 41 deletions(-) diff --git a/configs/_base_/models/point_rcnn.py b/configs/_base_/models/point_rcnn.py index e32c6f5084..fb2e03c34a 100644 --- a/configs/_base_/models/point_rcnn.py +++ b/configs/_base_/models/point_rcnn.py @@ -126,5 +126,3 @@ use_rotate_nms=True, iou_thr=0.85, nms_pre=9000, nms_post=512), score_thr=None), rcnn=dict(use_rotate_nms=True, nms_thr=0.1, score_thr=0.1))) - -find_unused_parameters = True diff --git a/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py b/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py index 2e1c2d5a49..77ccb25823 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py +++ b/mmdet3d/models/roi_heads/bbox_heads/point_rcnn_bbox_head.py @@ -271,46 +271,42 @@ def loss(self, cls_score, bbox_pred, rois, labels, bbox_targets, # calculate regression loss code_size = self.bbox_coder.code_size pos_inds = (reg_mask > 0) - if pos_inds.any() == 0: - # fake a part loss - losses['loss_bbox'] = loss_cls.new_tensor(0) - if self.with_corner_loss: - losses['loss_corner'] = loss_cls.new_tensor(0) + + pos_bbox_pred = bbox_pred.view(rcnn_batch_size, -1)[pos_inds].clone() + bbox_weights_flat = bbox_weights[pos_inds].view(-1, 1).repeat( + 1, pos_bbox_pred.shape[-1]) + loss_bbox = self.loss_bbox( + pos_bbox_pred.unsqueeze(dim=0), + bbox_targets.unsqueeze(dim=0).detach(), + bbox_weights_flat.unsqueeze(dim=0)) + losses['loss_bbox'] = loss_bbox + + if pos_inds.any() != 0 and self.with_corner_loss: + rois = rois.detach() + pos_roi_boxes3d = rois[..., 1:].view(-1, code_size)[pos_inds] + pos_roi_boxes3d = pos_roi_boxes3d.view(-1, code_size) + batch_anchors = pos_roi_boxes3d.clone().detach() + pos_rois_rotation = pos_roi_boxes3d[..., 6].view(-1) + roi_xyz = pos_roi_boxes3d[..., 0:3].view(-1, 3) + batch_anchors[..., 0:3] = 0 + # decode boxes + pred_boxes3d = self.bbox_coder.decode( + batch_anchors, + pos_bbox_pred.view(-1, code_size)).view(-1, code_size) + + pred_boxes3d[..., 0:3] = rotation_3d_in_axis( + pred_boxes3d[..., 0:3].unsqueeze(1), (pos_rois_rotation), + axis=2).squeeze(1) + + pred_boxes3d[:, 0:3] += roi_xyz + + # calculate corner loss + loss_corner = self.get_corner_loss_lidar(pred_boxes3d, + pos_gt_bboxes) + + losses['loss_corner'] = loss_corner else: - pos_bbox_pred = bbox_pred.view(rcnn_batch_size, - -1)[pos_inds].clone() - bbox_weights_flat = bbox_weights[pos_inds].view(-1, 1).repeat( - 1, pos_bbox_pred.shape[-1]) - loss_bbox = self.loss_bbox( - pos_bbox_pred.unsqueeze(dim=0), - bbox_targets.unsqueeze(dim=0).detach(), - bbox_weights_flat.unsqueeze(dim=0)) - losses['loss_bbox'] = loss_bbox - - if self.with_corner_loss: - rois = rois.detach() - pos_roi_boxes3d = rois[..., 1:].view(-1, code_size)[pos_inds] - pos_roi_boxes3d = pos_roi_boxes3d.view(-1, code_size) - batch_anchors = pos_roi_boxes3d.clone().detach() - pos_rois_rotation = pos_roi_boxes3d[..., 6].view(-1) - roi_xyz = pos_roi_boxes3d[..., 0:3].view(-1, 3) - batch_anchors[..., 0:3] = 0 - # decode boxes - pred_boxes3d = self.bbox_coder.decode( - batch_anchors, - pos_bbox_pred.view(-1, code_size)).view(-1, code_size) - - pred_boxes3d[..., 0:3] = rotation_3d_in_axis( - pred_boxes3d[..., 0:3].unsqueeze(1), (pos_rois_rotation), - 
axis=2).squeeze(1)
-
-                pred_boxes3d[:, 0:3] += roi_xyz
-
-                # calculate corner loss
-                loss_corner = self.get_corner_loss_lidar(
-                    pred_boxes3d, pos_gt_bboxes)
-
-                losses['loss_corner'] = loss_corner
+            losses['loss_corner'] = loss_cls.new_tensor(0)
 
         return losses

From bbd545c76429b485f584bf9a4b95585c4153ed5e Mon Sep 17 00:00:00 2001
From: Wenhao Wu
Date: Wed, 8 Dec 2021 15:33:09 +0800
Subject: [PATCH 6/7] resolve typo & add docstring

---
 configs/_base_/models/point_rcnn.py          |  2 +-
 mmdet3d/models/dense_heads/point_rpn_head.py | 10 ++++++++++
 mmdet3d/models/detectors/point_rcnn.py       |  2 +-
 3 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/configs/_base_/models/point_rcnn.py b/configs/_base_/models/point_rcnn.py
index fb2e03c34a..7a9667e733 100644
--- a/configs/_base_/models/point_rcnn.py
+++ b/configs/_base_/models/point_rcnn.py
@@ -92,7 +92,7 @@
                 neg_iou_thr=0.55,
                 min_pos_iou=0.55,
                 ignore_iof_thr=-1),
-            dict(  # for Pestrian
+            dict(  # for Pedestrian
                 type='MaxIoUAssigner',
                 iou_calculator=dict(
                     type='BboxOverlaps3D', coordinate='lidar'),
diff --git a/mmdet3d/models/dense_heads/point_rpn_head.py b/mmdet3d/models/dense_heads/point_rpn_head.py
index 203255e613..a0cc454462 100644
--- a/mmdet3d/models/dense_heads/point_rpn_head.py
+++ b/mmdet3d/models/dense_heads/point_rpn_head.py
@@ -100,6 +100,16 @@ def _get_reg_out_channels(self):
         return self.bbox_coder.code_size
 
     def forward(self, feat_dict):
+        """Forward pass.
+
+        Args:
+            feat_dict (dict): Feature dict from backbone.
+
+        Returns:
+            tuple:
+                point_box_preds (list[Tensor]): Predicted Boxes.
+                point_cls_preds (list[Tensor]): Predicted Boxes scores.
+        """
         point_features = feat_dict['fp_features']
         point_features = point_features.permute(0, 2, 1).contiguous()
diff --git a/mmdet3d/models/detectors/point_rcnn.py b/mmdet3d/models/detectors/point_rcnn.py
index 6f2238a402..94b40afd7a 100644
--- a/mmdet3d/models/detectors/point_rcnn.py
+++ b/mmdet3d/models/detectors/point_rcnn.py
@@ -8,7 +8,7 @@
 class PointRCNN(TwoStage3DDetector):
     r"""PointRCNN detector.
 
-    Please refer to the `PointRCNN https://arxiv.org/abs/1812.04244`_
+    Please refer to the `PointRCNN <https://arxiv.org/abs/1812.04244>`_
 
     Args:
         backbone (dict): Config dict of detector's backbone.

From 6781b91d528a4a0435ca960b901e935a6ccd497a Mon Sep 17 00:00:00 2001
From: Wenhao Wu
Date: Wed, 8 Dec 2021 16:19:57 +0800
Subject: [PATCH 7/7] resolve comments

---
 mmdet3d/models/dense_heads/point_rpn_head.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mmdet3d/models/dense_heads/point_rpn_head.py b/mmdet3d/models/dense_heads/point_rpn_head.py
index a0cc454462..ef1ec704e1 100644
--- a/mmdet3d/models/dense_heads/point_rpn_head.py
+++ b/mmdet3d/models/dense_heads/point_rpn_head.py
@@ -106,9 +106,8 @@ def forward(self, feat_dict):
             feat_dict (dict): Feature dict from backbone.
 
         Returns:
-            tuple:
-                point_box_preds (list[Tensor]): Predicted Boxes.
-                point_cls_preds (list[Tensor]): Predicted Boxes scores.
+            tuple[list[torch.Tensor]]: Predicted boxes and classification
+                scores.
         """
         point_features = feat_dict['fp_features']
         point_features = point_features.permute(0, 2, 1).contiguous()
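Taken together, patches 6 and 7 settle the documented contract of `PointRPNHead.forward`: `fp_features` of shape (B, C, N) from the FP neck go in, and a `(point_box_preds, point_cls_preds)` tuple of shapes (B, N, 8) and (B, N, 3) comes out. The following self-contained sketch mirrors that flatten-predict-reshape pattern; the two-layer MLP sizes and the class name are illustrative assumptions standing in for the configured `pred_layer_cfg`, not the real layers.

import torch
from torch import nn


class TinyPointRPNHead(nn.Module):
    # Hypothetical stand-in for PointRPNHead: per-point features are
    # flattened across the batch, run through small cls/reg MLPs and
    # reshaped back to per-batch predictions.

    def __init__(self, in_channels=128, num_classes=3, code_size=8):
        super().__init__()
        self.num_classes = num_classes
        self.code_size = code_size
        self.cls_layers = nn.Sequential(
            nn.Linear(in_channels, 256), nn.ReLU(inplace=True),
            nn.Linear(256, num_classes))
        self.reg_layers = nn.Sequential(
            nn.Linear(in_channels, 256), nn.ReLU(inplace=True),
            nn.Linear(256, code_size))

    def forward(self, feat_dict):
        point_features = feat_dict['fp_features']  # (B, C, N) from the neck
        point_features = point_features.permute(0, 2, 1).contiguous()
        batch_size = point_features.shape[0]
        flat_feats = point_features.view(-1, point_features.shape[-1])
        point_cls_preds = self.cls_layers(flat_feats).reshape(
            batch_size, -1, self.num_classes)
        point_box_preds = self.reg_layers(flat_feats).reshape(
            batch_size, -1, self.code_size)
        return point_box_preds, point_cls_preds


# shape check mirroring the rpn-head unit test added earlier in this series
head = TinyPointRPNHead()
box_preds, cls_preds = head(dict(fp_features=torch.rand(2, 128, 1024)))
assert box_preds.shape == (2, 1024, 8)
assert cls_preds.shape == (2, 1024, 3)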