Support CenterFormer in projects #2173

Merged · 18 commits · Jan 5, 2023
Changes from 13 commits
3 changes: 1 addition & 2 deletions mmdet3d/apis/inference.py
@@ -67,11 +67,10 @@ def init_model(config: Union[str, Path, Config],

     if checkpoint is not None:
         checkpoint = load_checkpoint(model, checkpoint, map_location='cpu')
-        dataset_meta = checkpoint['meta'].get('dataset_meta', None)
         # save the dataset_meta in the model for convenience
         if 'dataset_meta' in checkpoint.get('meta', {}):
             # mmdet3d 1.x
-            model.dataset_meta = dataset_meta
+            model.dataset_meta = checkpoint['meta']['dataset_meta']
         elif 'CLASSES' in checkpoint.get('meta', {}):
             # < mmdet3d 1.x
             classes = checkpoint['meta']['CLASSES']
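For context, a minimal usage sketch of the touched code path (the config and checkpoint paths are placeholders, not files from this PR):

```python
# Hypothetical sketch: after this fix, dataset_meta is only read when the
# checkpoint actually carries it (mmdet3d 1.x); older checkpoints fall
# through to the 'CLASSES' branch above.
from mmdet3d.apis import init_model

model = init_model('path/to/config.py', 'path/to/checkpoint.pth',
                   device='cuda:0')
print(model.dataset_meta)
```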
8 changes: 8 additions & 0 deletions mmdet3d/datasets/transforms/loading.py
@@ -532,6 +532,8 @@ class LoadPointsFromFile(BaseTransform):
             or use_dim=[0, 1, 2, 3] to use the intensity dimension.
         shift_height (bool): Whether to use shifted height. Defaults to False.
         use_color (bool): Whether to use color features. Defaults to False.
+        norm_intensity (bool): Whether to normalize the intensity. Defaults
+            to False.
         file_client_args (dict): Arguments to instantiate a FileClient.
             See :class:`mmengine.fileio.FileClient` for details.
             Defaults to dict(backend='disk').
@@ -544,6 +546,7 @@ def __init__(
         use_dim: Union[int, List[int]] = [0, 1, 2],
         shift_height: bool = False,
         use_color: bool = False,
+        norm_intensity: bool = False,
         file_client_args: dict = dict(backend='disk')
     ) -> None:
         self.shift_height = shift_height
@@ -557,6 +560,7 @@ def __init__(
         self.coord_type = coord_type
         self.load_dim = load_dim
         self.use_dim = use_dim
+        self.norm_intensity = norm_intensity
         self.file_client_args = file_client_args.copy()
         self.file_client = None

@@ -599,6 +603,10 @@ def transform(self, results: dict) -> dict:
         points = self._load_points(pts_file_path)
         points = points.reshape(-1, self.load_dim)
         points = points[:, self.use_dim]
+        if self.norm_intensity:
+            assert len(self.use_dim) >= 4, \
+                f'When using intensity norm, expect used dimensions >= 4, got {len(self.use_dim)}'  # noqa: E501
+            points[:, 3] = np.tanh(points[:, 3])
         attribute_dims = None

         if self.shift_height:
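A minimal sketch of enabling the new flag in a data pipeline (the dim values are illustrative Waymo-style settings, not taken from this PR):

```python
# Hypothetical pipeline entry: tanh-squash the intensity channel (dim 3).
# The assert above requires len(use_dim) >= 4, which is satisfied here.
load_points = dict(
    type='LoadPointsFromFile',
    coord_type='LIDAR',
    load_dim=6,   # e.g. x, y, z, intensity, elongation, timestamp
    use_dim=5,    # expands to [0, 1, 2, 3, 4]
    norm_intensity=True)
```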
4 changes: 4 additions & 0 deletions mmdet3d/datasets/transforms/transforms_3d.py
@@ -359,6 +359,7 @@ def __init__(self,
             db_sampler['type'] = 'DataBaseSampler'
         self.db_sampler = TRANSFORMS.build(db_sampler)
         self.use_ground_plane = use_ground_plane
+        self.disabled = False

     @staticmethod
     def remove_points_in_boxes(points: BasePoints,
@@ -387,6 +388,9 @@ def transform(self, input_dict: dict) -> dict:
             'points', 'gt_bboxes_3d', 'gt_labels_3d' keys are updated
             in the result dict.
         """
+        if self.disabled:
+            return input_dict
+
         gt_bboxes_3d = input_dict['gt_bboxes_3d']
         gt_labels_3d = input_dict['gt_labels_3d']
5 changes: 4 additions & 1 deletion mmdet3d/engine/hooks/__init__.py
@@ -1,5 +1,8 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .benchmark_hook import BenchmarkHook
+from .disable_object_sample_hook import DisableObjectSampleHook
 from .visualization_hook import Det3DVisualizationHook

-__all__ = ['Det3DVisualizationHook', 'BenchmarkHook']
+__all__ = [
+    'Det3DVisualizationHook', 'BenchmarkHook', 'DisableObjectSampleHook'
+]
54 changes: 54 additions & 0 deletions mmdet3d/engine/hooks/disable_object_sample_hook.py
@@ -0,0 +1,54 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.runner import Runner

from mmdet3d.datasets.transforms import ObjectSample
from mmdet3d.registry import HOOKS


@HOOKS.register_module()
class DisableObjectSampleHook(Hook):
    """Hook that disables the ``ObjectSample`` augmentation during training.

    Args:
        disable_after_epoch (int): ``ObjectSample`` is disabled once
            training reaches this epoch. Defaults to 15.
    """

    def __init__(self, disable_after_epoch: int = 15):
        self.disable_after_epoch = disable_after_epoch
        self._restart_dataloader = False

    def before_train_epoch(self, runner: Runner):
        """Disable ``ObjectSample`` once the target epoch is reached.

        Args:
            runner (Runner): The runner.
        """
        epoch = runner.epoch
        train_loader = runner.train_dataloader
        model = runner.model
        # TODO: refactor after mmengine using model wrapper
        if is_model_wrapper(model):
            model = model.module
        if epoch == self.disable_after_epoch:
            runner.logger.info('Disable ObjectSample')
            for transform in runner.train_dataloader.dataset.pipeline.transforms:  # noqa: E501
                if isinstance(transform, ObjectSample):
                    assert hasattr(transform, 'disabled')
                    transform.disabled = True
            # The dataset pipeline cannot be updated when persistent_workers
            # is True, so we need to force the dataloader's multi-process
            # restart. This is a very hacky approach.
            if hasattr(train_loader, 'persistent_workers'
                       ) and train_loader.persistent_workers is True:
                train_loader._DataLoader__initialized = False
                train_loader._iterator = None
                self._restart_dataloader = True
        else:
            # Once the restart is complete, we need to restore
            # the initialization flag.
            if self._restart_dataloader:
                train_loader._DataLoader__initialized = True
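For reference, a sketch of how this hook would typically be registered in a training config (the epoch value shown is just the hook's default):

```python
# Hypothetical config snippet: stop the GT-paste augmentation
# (ObjectSample) once training reaches epoch 15.
custom_hooks = [
    dict(type='DisableObjectSampleHook', disable_after_epoch=15)
]
```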
4 changes: 1 addition & 3 deletions mmdet3d/models/layers/spconv/__init__.py
@@ -6,9 +6,7 @@
 except ImportError:
     IS_SPCONV2_AVAILABLE = False
 else:
-    if hasattr(spconv,
-               '__version__') and spconv.__version__ >= '2.0.0' and hasattr(
-                   spconv, 'pytorch'):
+    if hasattr(spconv, '__version__') and spconv.__version__ >= '2.0.0':
         IS_SPCONV2_AVAILABLE = register_spconv2()
     else:
         IS_SPCONV2_AVAILABLE = False
99 changes: 99 additions & 0 deletions projects/centerformer/README.md
@@ -0,0 +1,99 @@
# CenterFormer: Center-based Transformer for 3D Object Detection

> [CenterFormer: Center-based Transformer for 3D Object Detection](https://arxiv.org/abs/2209.05588)

<!-- [ALGORITHM] -->

## Abstract

Query-based transformer has shown great potential in constructing long-range attention in many image-domain tasks, but has rarely been considered in LiDAR-based 3D object detection due to the overwhelming size of the point cloud data. In this paper, we propose CenterFormer, a center-based transformer network for 3D object detection. CenterFormer first uses a center heatmap to select center candidates on top of a standard voxel-based point cloud encoder. It then uses the feature of the center candidate as the query embedding in the transformer. To further aggregate features from multiple frames, we design an approach to fuse features through cross-attention. Lastly, regression heads are added to predict the bounding box on the output center feature representation. Our design reduces the convergence difficulty and computational complexity of the transformer structure. The results show significant improvements over the strong baseline of anchor-free object detection networks. CenterFormer achieves state-of-the-art performance for a single model on the Waymo Open Dataset, with 73.7% mAPH on the validation set and 75.6% mAPH on the test set, significantly outperforming all previously published CNN and transformer-based methods. Our code is publicly available at https://github.com/TuSimple/centerformer

<div align=center>
<img src="https://user-images.githubusercontent.com/34888372/209500088-b707d7cd-d4d5-4f20-8fdf-a2c7ad15df34.png" width="800"/>
</div>

## Introduction

We implement CenterFormer and provide its results and checkpoints on the Waymo dataset.

We follow the style below to name config files; contributors are advised to follow the same style.
`{xxx}` is a required field and `[yyy]` is optional.

`{model}`: model type like `centerpoint`.

`{model setting}`: voxel size and voxel type like `01voxel`, `02pillar`.

`{backbone}`: backbone type like `second`.

`{neck}`: neck type like `secfpn`.

`[gpu x batch_per_gpu]`: GPUs and samples per GPU; `4xb4` (4 GPUs, 4 samples per GPU) is used in this project.

`{schedule}`: training schedule, options are 1x, 2x, 20e, etc. 1x and 2x means 12 epochs and 24 epochs respectively. 20e is adopted in cascade models, which denotes 20 epochs. For 1x/2x, initial learning rate decays by a factor of 10 at the 8/16th and 11/22th epochs. For 20e, initial learning rate decays by a factor of 10 at the 16th and 19th epochs.

`{dataset}`: dataset like nus-3d, kitti-3d, lyft-3d, scannet-3d, sunrgbd-3d. We also indicate the number of classes we are using if there exist multiple settings, e.g., kitti-3d-3class and kitti-3d-car means training on KITTI dataset with 3 classes and single class, respectively.
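For example, the config name used in this project, `centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py`, decomposes as: model `centerformer`, voxel setting `voxel01` (0.1 m voxels), backbone `second-attn`, neck `secfpn-attn`, `4xb4-cyclic-20e` (4 GPUs, 4 samples per GPU, cyclic 20-epoch schedule), and dataset `waymoD5-3d-class` (Waymo with load interval 5).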

## Usage

<!-- For a typical model, this section should contain the commands for training and testing. You are also suggested to dump your environment specification to env.yml by `conda env export > env.yml`. -->

### Training commands

In MMDetection3D's root directory, run the following command to train the model:

```bash
python tools/train.py projects/centerformer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py
```

For multi-GPU training, run:

```bash
python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=${NUM_GPUS} --master_port=29506 --master_addr="127.0.0.1" tools/train.py projects/centerformer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py
```

### Testing commands

In MMDetection3D's root directory, run the following command to test the model:

```bash
python tools/test.py projects/centerformer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py ${CHECKPOINT_PATH}
```

## Results and models

### Waymo

| Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | Mem (GB) | Inf time (fps) | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
| :----------------------------------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :------: | :------------: | :----: | :-----: | :----: | :---------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| [SECFPN_WithAttention](./configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py) | 5 | voxel (0.1) | ✓ | × | 14.8 | | 72.2 | 69.5 | 65.9 | 63.3 | [log](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/centerformer/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class_20221227_205613-70c9ad37.json) |

**Note** that `SECFPN_WithAttention` denotes both SECOND and SECONDFPN with ChannelAttention and SpatialAttention.

## Citation

```latex
@InProceedings{Zhou_centerformer,
title = {CenterFormer: Center-based Transformer for 3D Object Detection},
author = {Zhou, Zixiang and Zhao, Xiangchen and Wang, Yu and Wang, Panqu and Foroosh, Hassan},
booktitle = {ECCV},
year = {2022}
}
```
11 changes: 11 additions & 0 deletions projects/centerformer/centerformer/__init__.py
@@ -0,0 +1,11 @@
from .bbox_ops import nms_iou3d
from .centerformer import CenterFormer
from .centerformer_backbone import (DeformableDecoderRPN,
                                    MultiFrameDeformableDecoderRPN)
from .centerformer_head import CenterFormerBboxHead
from .losses import FastFocalLoss

__all__ = [
    'CenterFormer', 'DeformableDecoderRPN', 'CenterFormerBboxHead',
    'FastFocalLoss', 'nms_iou3d', 'MultiFrameDeformableDecoderRPN'
]
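Since these modules live under `projects/`, a config presumably pulls them into the registries via mmengine's custom-import mechanism; a sketch, with the module path inferred from the file layout above:

```python
# Hypothetical config line: import the project package so its registry
# decorators (e.g. @MODELS.register_module()) run before model building.
custom_imports = dict(
    imports=['projects.centerformer.centerformer'],
    allow_failed_imports=False)
```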
41 changes: 41 additions & 0 deletions projects/centerformer/centerformer/bbox_ops.py
@@ -0,0 +1,41 @@
import torch
from mmcv.utils import ext_loader

ext_module = ext_loader.load_ext('_ext', ['iou3d_nms3d_forward'])


def nms_iou3d(boxes, scores, thresh, pre_maxsize=None, post_max_size=None):
    """GPU implementation of NMS using 3D IoU. It differs from ``nms3d`` in
    MMCV in that it applies ``pre_maxsize`` and ``post_max_size`` limits
    before and after NMS, respectively.

    Args:
        boxes (Tensor): Input boxes with the shape of [N, 7]
            ([cx, cy, cz, l, w, h, theta]).
        scores (Tensor): Scores of boxes with the shape of [N].
        thresh (float): Overlap threshold of NMS.
        pre_maxsize (int, optional): Max number of boxes kept before NMS.
            Defaults to None.
        post_max_size (int, optional): Max number of boxes kept after NMS.
            Defaults to None.

    Returns:
        Tensor: Indices of the boxes kept after NMS.
    """
    # TODO: directly refactor ``nms3d`` in MMCV
    assert boxes.size(1) == 7, 'Input boxes shape should be (N, 7)'
    order = scores.sort(0, descending=True)[1]
    if pre_maxsize is not None:
        order = order[:pre_maxsize]
    boxes = boxes[order].contiguous()

    keep = boxes.new_zeros(boxes.size(0), dtype=torch.long)
    num_out = boxes.new_zeros(size=(), dtype=torch.long)
    ext_module.iou3d_nms3d_forward(
        boxes, keep, num_out, nms_overlap_thresh=thresh)
    keep = order[keep[:num_out].to(boxes.device)].contiguous()

    if post_max_size is not None:
        keep = keep[:post_max_size]

    return keep
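A short usage sketch (random inputs and illustrative thresholds; the size caps are not values from this PR):

```python
# Hypothetical usage: rank 1000 random boxes by score, keep at most 500
# candidates before NMS and at most 83 survivors afterwards.
import torch

boxes = torch.rand(1000, 7, device='cuda')   # [cx, cy, cz, l, w, h, theta]
scores = torch.rand(1000, device='cuda')
keep = nms_iou3d(boxes, scores, thresh=0.7,
                 pre_maxsize=500, post_max_size=83)
print(keep.numel())  # <= 83
```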