Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] Refactor Point RCNN and fix it. #1819

Merged
merged 20 commits into from
Sep 30, 2022
45 changes: 31 additions & 14 deletions configs/_base_/models/point_rcnn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
model = dict(
type='PointRCNN',
data_preprocessor=dict(type='Det3DDataPreprocessor'),
backbone=dict(
type='PointNet2SAMSG',
in_channels=4,
Expand Down Expand Up @@ -34,14 +35,14 @@
cls_linear_channels=(256, 256),
reg_linear_channels=(256, 256)),
cls_loss=dict(
type='FocalLoss',
type='mmdet.FocalLoss',
use_sigmoid=True,
reduction='sum',
gamma=2.0,
alpha=0.25,
loss_weight=1.0),
bbox_loss=dict(
type='SmoothL1Loss',
type='mmdet.SmoothL1Loss',
beta=1.0 / 9.0,
reduction='sum',
loss_weight=1.0),
Expand All @@ -55,12 +56,22 @@
1.73]])),
roi_head=dict(
type='PointRCNNRoIHead',
point_roi_extractor=dict(
bbox_roi_extractor=dict(
type='Single3DRoIPointExtractor',
roi_layer=dict(type='RoIPointPool3d', num_sampled_points=512)),
bbox_head=dict(
type='PointRCNNBboxHead',
num_classes=1,
loss_bbox=dict(
type='mmdet.SmoothL1Loss',
beta=1.0 / 9.0,
reduction='sum',
loss_weight=1.0),
loss_cls=dict(
type='mmdet.CrossEntropyLoss',
use_sigmoid=True,
reduction='sum',
loss_weight=1.0),
pred_layer_cfg=dict(
in_channels=512,
cls_conv_channels=(256, 256),
Expand All @@ -79,31 +90,34 @@
train_cfg=dict(
pos_distance_thr=10.0,
rpn=dict(
nms_cfg=dict(
use_rotate_nms=True, iou_thr=0.8, nms_pre=9000, nms_post=512),
score_thr=None),
rpn_proposal=dict(
use_rotate_nms=True,
score_thr=None,
iou_thr=0.8,
nms_pre=9000,
nms_post=512)),
rcnn=dict(
assigner=[
dict( # for Car
type='MaxIoUAssigner',
dict( # for Pedestrian
type='Max3DIoUAssigner',
iou_calculator=dict(
type='BboxOverlaps3D', coordinate='lidar'),
pos_iou_thr=0.55,
neg_iou_thr=0.55,
min_pos_iou=0.55,
ignore_iof_thr=-1,
match_low_quality=False),
dict( # for Pedestrian
type='MaxIoUAssigner',
dict( # for Cyclist
type='Max3DIoUAssigner',
iou_calculator=dict(
type='BboxOverlaps3D', coordinate='lidar'),
pos_iou_thr=0.55,
neg_iou_thr=0.55,
min_pos_iou=0.55,
ignore_iof_thr=-1,
match_low_quality=False),
dict( # for Cyclist
type='MaxIoUAssigner',
dict( # for Car
type='Max3DIoUAssigner',
iou_calculator=dict(
type='BboxOverlaps3D', coordinate='lidar'),
pos_iou_thr=0.55,
Expand All @@ -126,6 +140,9 @@
test_cfg=dict(
rpn=dict(
nms_cfg=dict(
use_rotate_nms=True, iou_thr=0.85, nms_pre=9000, nms_post=512),
score_thr=None),
use_rotate_nms=True,
iou_thr=0.85,
nms_pre=9000,
nms_post=512,
score_thr=None)),
rcnn=dict(use_rotate_nms=True, nms_thr=0.1, score_thr=0.1)))
90 changes: 59 additions & 31 deletions configs/point_rcnn/point-rcnn_8xb2_kitti-3d-3class.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
# dataset settings
dataset_type = 'KittiDataset'
data_root = 'data/kitti/'
class_names = ['Car', 'Pedestrian', 'Cyclist']
class_names = ['Pedestrian', 'Cyclist', 'Car']
metainfo = dict(CLASSES=class_names)
point_cloud_range = [0, -40, -3, 70.4, 40, 1]
input_modality = dict(use_lidar=True, use_camera=False)

Expand Down Expand Up @@ -42,8 +43,9 @@
dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='PointSample', num_points=16384, sample_range=40.0),
dict(type='PointShuffle'),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
dict(
type='Pack3DDetInputs',
keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
test_pipeline = [
dict(type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=4),
Expand All @@ -61,41 +63,67 @@
dict(type='RandomFlip3D'),
dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='PointSample', num_points=16384, sample_range=40.0),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
with_label=False),
dict(type='Collect3D', keys=['points'])
])
dict(type='PointSample', num_points=16384, sample_range=40.0)
]),
dict(type='Pack3DDetInputs', keys=['points'])
]

data = dict(
samples_per_gpu=2,
workers_per_gpu=2,
train=dict(
train_dataloader = dict(
batch_size=2,
num_workers=2,
dataset=dict(
type='RepeatDataset',
times=2,
dataset=dict(pipeline=train_pipeline, classes=class_names)),
val=dict(pipeline=test_pipeline, classes=class_names),
test=dict(pipeline=test_pipeline, classes=class_names))
dataset=dict(pipeline=train_pipeline, metainfo=metainfo)))
test_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo))
val_dataloader = dict(dataset=dict(pipeline=test_pipeline, metainfo=metainfo))

# optimizer
lr = 0.001 # max learning rate
optimizer = dict(lr=lr, betas=(0.95, 0.85))
# runtime settings
runner = dict(type='EpochBasedRunner', max_epochs=80)
evaluation = dict(interval=2)
# yapf:disable
log_config = dict(
interval=30,
hooks=[
dict(type='TextLoggerHook'),
dict(type='TensorboardLoggerHook')
])
# yapf:enable
optim_wrapper = dict(optimizer=dict(lr=lr, betas=(0.95, 0.85)))
train_cfg = dict(by_epoch=True, max_epochs=80, val_interval=2)

# Default setting for scaling LR automatically
# - `enable` means enable scaling LR automatically
# or not by default.
# - `base_batch_size` = (8 GPUs) x (2 samples per GPU).
auto_scale_lr = dict(enable=False, base_batch_size=16)
param_scheduler = [
# learning rate scheduler
# During the first 35 epochs, learning rate increases from 0 to lr * 10
# during the next 45 epochs, learning rate decreases from lr * 10 to
# lr * 1e-4
dict(
type='CosineAnnealingLR',
T_max=35,
eta_min=lr * 10,
begin=0,
end=35,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='CosineAnnealingLR',
T_max=45,
eta_min=lr * 1e-4,
begin=35,
end=80,
by_epoch=True,
convert_to_iter_based=True),
# momentum scheduler
# During the first 35 epochs, momentum increases from 0 to 0.85 / 0.95
# during the next 45 epochs, momentum increases from 0.85 / 0.95 to 1
dict(
type='CosineAnnealingMomentum',
T_max=35,
eta_min=0.85 / 0.95,
begin=0,
end=35,
by_epoch=True,
convert_to_iter_based=True),
dict(
type='CosineAnnealingMomentum',
T_max=45,
eta_min=1,
begin=35,
end=80,
by_epoch=True,
convert_to_iter_based=True)
]
6 changes: 3 additions & 3 deletions docs/en/advanced_guides/customize_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -365,7 +365,7 @@ class PartAggregationROIHead(Base3DRoIHead):

Args:
feats_dict (dict): Contains features from the first stage.
rpn_results_list (List[:obj:`InstancesData`]): Detection results
rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as
Expand Down Expand Up @@ -412,7 +412,7 @@ class PartAggregationROIHead(Base3DRoIHead):
voxel_dict (dict): Contains information of voxels.
batch_input_metas (list[dict], Optional): Batch image meta info.
Defaults to None.
rpn_results_list (List[:obj:`InstancesData`]): Detection results
rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head.
test_cfg (Config): Test config.

Expand All @@ -438,7 +438,7 @@ class PartAggregationROIHead(Base3DRoIHead):

Args:
feats_dict (dict): Contains features from the first stage.
rpn_results_list (List[:obj:`InstancesData`]): Detection results
rpn_results_list (List[:obj:`InstanceData`]): Detection results
of rpn head.
batch_data_samples (List[:obj:`Det3DDataSample`]): The Data
samples. It usually includes information such as
Expand Down
2 changes: 1 addition & 1 deletion mmdet3d/models/dense_heads/base_3d_dense_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ def predict_by_feat(self,
score_factors (list[Tensor], optional): Score factor for
all scale level, each is a 4D-tensor, has shape
(batch_size, num_priors * 1, H, W). Defaults to None.
batch_input_metas (list[dict], Optional): Batch image meta info.
batch_input_metas (list[dict], Optional): Batch inputs meta info.
Defaults to None.
cfg (ConfigDict, optional): Test / postprocessing
configuration, if None, test_cfg would be used.
Expand Down
10 changes: 4 additions & 6 deletions mmdet3d/models/dense_heads/parta2_rpn_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,7 @@ def _predict_by_feat_single(self,
result = self.class_agnostic_nms(mlvl_bboxes, mlvl_bboxes_for_nms,
mlvl_max_scores, mlvl_label_pred,
mlvl_cls_score, mlvl_dir_scores,
score_thr, cfg.nms_post, cfg,
input_meta)
score_thr, cfg, input_meta)
return result

def loss_and_predict(self,
Expand Down Expand Up @@ -275,7 +274,7 @@ def class_agnostic_nms(self, mlvl_bboxes: Tensor,
mlvl_bboxes_for_nms: Tensor,
mlvl_max_scores: Tensor, mlvl_label_pred: Tensor,
mlvl_cls_score: Tensor, mlvl_dir_scores: Tensor,
score_thr: int, max_num: int, cfg: ConfigDict,
score_thr: int, cfg: ConfigDict,
input_meta: dict) -> Dict:
"""Class agnostic nms for single batch.

Expand All @@ -291,7 +290,6 @@ def class_agnostic_nms(self, mlvl_bboxes: Tensor,
mlvl_dir_scores (torch.Tensor): Direction scores of
Multi-level bbox.
score_thr (int): Score threshold.
max_num (int): Max number of bboxes after nms.
cfg (:obj:`ConfigDict`): Training or testing config.
input_meta (dict): Contain pcd and img's meta info.

Expand Down Expand Up @@ -339,9 +337,9 @@ def class_agnostic_nms(self, mlvl_bboxes: Tensor,
scores = torch.cat(scores, dim=0)
cls_scores = torch.cat(cls_scores, dim=0)
labels = torch.cat(labels, dim=0)
if bboxes.shape[0] > max_num:
if bboxes.shape[0] > cfg.nms_post:
_, inds = scores.sort(descending=True)
inds = inds[:max_num]
inds = inds[:cfg.nms_post]
bboxes = bboxes[inds, :]
labels = labels[inds]
scores = scores[inds]
Expand Down
Loading