CasA is a simple multi-stage 3D object detection framework based on a Cascade Attention design. CasA can be integrated into many state-of-the-art 3D detectors and greatly improves their detection performance (see the KITTI and Waymo results below).
The paper "CasA: A Cascade Attention Network for 3D Object Detection from LiDAR point clouds" is published in IEEE Transactions on Geoscience and Remote Sensing (2022).
This code is mostly built upon OpenPCDet. Note that CasA++ relies on a transfer-learning framework (pre-training on Waymo and fine-tuning on KITTI). Because it uses additional training data, we do not release the CasA++ code.
Cascade frameworks have been widely studied in 2D object detection but are less investigated in 3D space. Conventional cascade structures use multiple separate sub-networks to sequentially refine region proposals. Such methods, however, have a limited ability to measure proposal quality in all stages, and it is hard for them to achieve a desirable performance improvement in 3D space. We propose a new cascade framework, termed CasA, for 3D object detection from point clouds. CasA consists of a Region Proposal Network (RPN) and a Cascade Refinement Network (CRN). In the CRN, we design a new Cascade Attention Module that uses multiple sub-networks and attention modules to aggregate object features from different stages and progressively refine region proposals. CasA can be integrated into various two-stage 3D detectors and greatly improves their detection performance. Extensive experiments on the KITTI and Waymo datasets with various baseline detectors demonstrate the universality and superiority of CasA. In particular, based on a variant of Voxel-RCNN, we achieve state-of-the-art results on the KITTI 3D object detection benchmark.
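The cascade-attention idea can be illustrated with a small, framework-free sketch. It assumes single-head scaled dot-product attention with no learned projections: the current stage's RoI feature is the query, and the features of all stages so far are the keys and values. The actual module in the paper and in this repo uses learned layers, and each stage's box-refinement head is omitted here; `cascade_attention` and `softmax` are hypothetical helper names.

```python
import math
from typing import List, Sequence

Vector = List[float]


def softmax(scores: Sequence[float]) -> List[float]:
    # Subtract the max score for numerical stability.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cascade_attention(stage_features: Sequence[Vector]) -> Vector:
    """Fuse one proposal's RoI features from stages 1..j into a single vector.

    Query: the current (last) stage's feature. Keys/values: the features of
    all stages so far. Plain scaled dot-product attention without learned
    projections; a simplification, not the paper's module.
    """
    query = stage_features[-1]
    dim = len(query)
    scores = [sum(q * k for q, k in zip(query, f)) / math.sqrt(dim)
              for f in stage_features]
    weights = softmax(scores)
    return [sum(w * f[i] for w, f in zip(weights, stage_features))
            for i in range(dim)]


if __name__ == "__main__":
    # Hypothetical 4-D RoI features of one proposal after stages 1, 2 and 3.
    history: List[Vector] = []
    for stage_feat in ([0.2, 0.1, 0.0, 0.5], [0.3, 0.0, 0.1, 0.6], [0.4, 0.1, 0.1, 0.7]):
        history.append(stage_feat)
        fused = cascade_attention(history)  # a refinement head would consume this
        print(len(history), [round(x, 3) for x in fused])
```

Because the weights come from a softmax, the fused feature is always a convex combination of the per-stage features, so later stages can draw on earlier ones without discarding them.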
- 2022/10/15: Added CasTrack, a 3D multi-object tracker based on CasA detections; it currently ranks first on the KITTI tracking leaderboard.
- 2022/9/30: Updated the installation details and the tested environments; added Spconv 2.x support.
- 2022/3/3: Initial release, achieving state-of-the-art performance on the KITTI 3D detection leaderboard.
The results below are 3D detection AP at moderate difficulty on the KITTI val set. Currently, this repo supports CasA-PV, CasA-V, CasA-T and CasA-PV2, whose base detectors are PV-RCNN, Voxel-RCNN, CT3D and PV-RCNN++, respectively.
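For reference, the moderate level follows KITTI's published difficulty definitions (minimum 2D box height, maximum occlusion level, maximum truncation). The sketch below encodes them; the helper names are hypothetical, and boxes lying exactly on a threshold may be handled differently by the official evaluation code.

```python
from typing import Optional

# KITTI difficulty definitions: minimum 2D box height (pixels), maximum
# occlusion level (0 = fully visible, 1 = partly occluded, 2 = largely
# occluded) and maximum truncation ratio.
LEVELS = ("easy", "moderate", "hard")
MIN_HEIGHT = (40, 25, 25)
MAX_OCCLUSION = (0, 1, 2)
MAX_TRUNCATION = (0.15, 0.30, 0.50)


def easiest_level(bbox_height: float, occlusion: int, truncation: float) -> Optional[str]:
    """Easiest level whose criteria a ground-truth box meets, or None (ignored)."""
    for name, min_h, max_occ, max_trunc in zip(LEVELS, MIN_HEIGHT, MAX_OCCLUSION, MAX_TRUNCATION):
        if bbox_height >= min_h and occlusion <= max_occ and truncation <= max_trunc:
            return name
    return None


def counted_at(level: str, bbox_height: float, occlusion: int, truncation: float) -> bool:
    """Thresholds are nested, so evaluating at 'moderate' also counts 'easy' boxes."""
    found = easiest_level(bbox_height, occlusion, truncation)
    return found is not None and LEVELS.index(found) <= LEVELS.index(level)
```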
- All released models are trained on 2 RTX 3090 GPUs and are available for download.
- These models are not suitable for directly reporting results on the KITTI test set. For a test-set submission, use a slightly lower score threshold and train on all (or 80%) of the training data to achieve the desired performance.
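Preparing an 80% training split can be sketched as a reproducible shuffle of frame IDs. This assumes the existing split files are `data/kitti/ImageSets/train.txt` and `val.txt`; the output names `train_80.txt` and `val_20.txt` are hypothetical, and you would still need to point the dataset config at the new split and regenerate the infos (the exact config keys in this repo are not verified here).

```python
import random
from pathlib import Path
from typing import List, Sequence, Tuple


def split_ids(ids: Sequence[str], train_fraction: float = 0.8, seed: int = 0) -> Tuple[List[str], List[str]]:
    """Deterministically shuffle frame IDs and split them into (train, holdout)."""
    shuffled = sorted(ids)                 # fixed starting order, so the split is reproducible
    random.Random(seed).shuffle(shuffled)
    cut = round(len(shuffled) * train_fraction)
    return sorted(shuffled[:cut]), sorted(shuffled[cut:])


def read_ids(path: Path) -> List[str]:
    return [line.strip() for line in path.read_text().splitlines() if line.strip()]


if __name__ == "__main__":
    # Assumed input files; hypothetical output names.
    image_sets = Path("data/kitti/ImageSets")
    if (image_sets / "train.txt").exists() and (image_sets / "val.txt").exists():
        all_ids = read_ids(image_sets / "train.txt") + read_ids(image_sets / "val.txt")
        train, holdout = split_ids(all_ids)
        (image_sets / "train_80.txt").write_text("\n".join(train) + "\n")
        (image_sets / "val_20.txt").write_text("\n".join(holdout) + "\n")
        print(len(train), "train /", len(holdout), "holdout frames")
```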
Detector | Car (R11/R40) | Pedestrian (R11/R40) | Cyclist (R11/R40) | Download |
---|---|---|---|---|
PV-RCNN baseline | 83.90/84.83 | 57.90/56.67 | 70.47/71.95 | |
CasA-PV | 86.18/85.86 | 58.90/59.17 | 66.01/69.09 | model-44M |
Voxel-RCNN baseline | 84.52/85.29 | 61.72/60.97 | 71.48/72.54 | |
CasA-V | 86.54/86.30 | 67.93/66.54 | 74.27/73.08 | model-44M |
CT3D3cat baseline | 84.97/85.04 | 56.28/55.58 | 71.71/71.88 | |
CasA-T | 86.76/86.44 | 60.91/62.53 | 73.36/71.83 | model-22M |
*PV-RCNN++ baseline | 85.36/85.50 | 57.43/57.15 | 71.30/71.85 | |
CasA-PV2 | 86.32/86.10 | 59.50/60.54 | 72.74/73.16 | model-47M |
Here, * denotes results reproduced with a simplified version of the authors' open-source code.
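R11 and R40 in the table headers refer to AP averaged over 11 and 40 recall positions. A minimal sketch of the interpolated-AP computation, assuming you already have a precision/recall curve (the official KITTI evaluation derives that curve from score thresholds, which is not shown here):

```python
from typing import Sequence

# Recall sampling positions: R11 includes recall 0; R40 starts at 1/40.
R11 = [i / 10 for i in range(11)]        # 0.0, 0.1, ..., 1.0
R40 = [i / 40 for i in range(1, 41)]     # 0.025, 0.05, ..., 1.0


def interpolated_ap(recalls: Sequence[float], precisions: Sequence[float],
                    recall_points: Sequence[float]) -> float:
    """Average over recall_points of the best precision at recall >= r, in percent."""
    total = 0.0
    for r in recall_points:
        candidates = [p for rec, p in zip(recalls, precisions) if rec >= r]
        total += max(candidates) if candidates else 0.0
    return 100.0 * total / len(recall_points)


if __name__ == "__main__":
    # Toy PR curve of a detector that never exceeds 50% recall.
    recalls, precisions = [0.25, 0.5], [1.0, 1.0]
    print("AP_R11 =", round(interpolated_ap(recalls, precisions, R11), 2))  # 54.55
    print("AP_R40 =", round(interpolated_ap(recalls, precisions, R40), 2))  # 50.0
```

The toy curve shows one reason the two numbers differ: R11 includes the recall-0 position, which inflates AP for detectors whose recall saturates early.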
We provide two models on the Waymo Open Dataset (WOD); CasA-V-Center uses a center-based RPN. All models are trained on single frames with 8 V100 GPUs, and each cell reports mAP/mAPH computed with the official Waymo evaluation metrics on the whole validation set (version 1.2). Vec, Ped and Cyc denote Vehicle, Pedestrian and Cyclist; L1 and L2 are the Waymo difficulty levels.
100% Data, 2 returns | Vec_L1 | Vec_L2 | Ped_L1 | Ped_L2 | Cyc_L1 | Cyc_L2 |
---|---|---|---|---|---|---|
*Voxel-RCNN baseline | 77.43/76.71 | 68.73/68.24 | 76.37/68.21 | 67.92/60.40 | 68.74/67.56 | 66.46/65.35 |
CasA-V | 78.54/78.00 | 69.91/69.42 | 80.88/73.10 | 71.87/64.78 | 69.66/68.38 | 67.07/66.83 |
CasA-V-Center | 78.62/78.04 | 69.94/69.47 | 81.76/75.69 | 72.75/67.21 | 72.47/71.18 | 70.20/68.94 |
Here, * denotes results we reproduced using the authors' open-source code.
We cannot release these pretrained models because of the Waymo Dataset License Agreement, but training with the default configs should yield similar performance.
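mAPH differs from mAP only in how true positives are counted: each matched detection is weighted by its heading accuracy, 1 − min(|θ̂ − θ|, 2π − |θ̂ − θ|)/π. The sketch below illustrates that weighting only; it is not the official Waymo metrics library (which should be used for reported numbers), and the function names are hypothetical.

```python
import math
from typing import Sequence, Tuple


def heading_accuracy(pred_yaw: float, gt_yaw: float) -> float:
    """1.0 for a perfect heading, 0.0 for one flipped by pi (yaws in radians)."""
    diff = abs(pred_yaw - gt_yaw) % (2 * math.pi)
    diff = min(diff, 2 * math.pi - diff)   # wrap the error into [0, pi]
    return 1.0 - diff / math.pi


def heading_weighted_tp(matches: Sequence[Tuple[float, float]]) -> float:
    """Sum of heading accuracies over matched (pred_yaw, gt_yaw) true positives.

    AP counts each true positive as 1; APH counts it as its heading accuracy.
    """
    return sum(heading_accuracy(p, g) for p, g in matches)


if __name__ == "__main__":
    matches = [(0.0, 0.0), (0.0, math.pi / 2), (0.0, math.pi)]
    print("AP-style TP count :", len(matches))                  # 3
    print("APH-style TP count:", heading_weighted_tp(matches))  # 1.5
```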
Create a conda environment and install the dependencies:
```shell
conda create -n spconv2 python=3.9
conda activate spconv2
pip install torch==1.8.1+cu111 torchvision==0.9.1+cu111 torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
pip install numpy==1.19.5 protobuf==3.19.4 scikit-image==0.19.2 waymo-open-dataset-tf-2-5-0 nuscenes-devkit==1.0.5 spconv-cu111 numba scipy pyyaml easydict fire tqdm shapely matplotlib opencv-python addict pyquaternion awscli open3d pandas future pybind11 tensorboardX tensorboard Cython prefetch-generator
```
Our released implementation has been tested on:
- Ubuntu 18.04
- Python 3.6.9
- PyTorch 1.8.1
- Numba 0.53.1
- Spconv 1.2.1
- NVIDIA CUDA 11.1
- 8x Tesla V100 GPUs
We have also tested it on:
- Ubuntu 18.04
- Python 3.9.13
- PyTorch 1.8.1
- Numba 0.53.1
- Spconv 2.1.22 (`pip install spconv-cu111`)
- NVIDIA CUDA 11.1
- 2x RTX 3090 GPUs
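To compare a local environment against the tested configuration, the sketch below queries installed package versions with the standard library's `importlib.metadata`. The distribution names are assumptions: Spconv 2.x wheels install as `spconv-cu111` (per the pip command above), whereas a Spconv 1.x build from source typically registers as `spconv`.

```python
import sys
from importlib import metadata
from typing import Dict, Optional, Tuple

# Versions from the second tested configuration above.
TESTED: Dict[str, str] = {
    "torch": "1.8.1",
    "numba": "0.53.1",
    "spconv-cu111": "2.1.22",   # Spconv 2.x wheel; use "spconv" for a 1.x source build
}


def installed_version(dist: str) -> Optional[str]:
    try:
        return metadata.version(dist)
    except metadata.PackageNotFoundError:
        return None


def compare(tested: Dict[str, str]) -> Dict[str, Tuple[str, Optional[str], bool]]:
    """Map each distribution to (tested, installed, matches); local tags like +cu111 are ignored."""
    report = {}
    for dist, want in tested.items():
        have = installed_version(dist)
        report[dist] = (want, have, have is not None and have.split("+")[0] == want)
    return report


if __name__ == "__main__":
    print("python", sys.version.split()[0], "(tested: 3.9.13)")
    for dist, (want, have, ok) in compare(TESTED).items():
        status = "ok" if ok else "differs"
        print(f"{dist:14} tested={want:8} installed={have or 'missing':14} {status}")
```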
- Please download the official KITTI 3D object detection dataset and organize the downloaded files as follows (the road planes, which are optional and only used for data augmentation during training, can be downloaded from [road plane]):
```
CasA
├── data
│   ├── kitti
│   │   ├── ImageSets
│   │   ├── training
│   │   │   ├── calib & velodyne & label_2 & image_2 & (optional: planes)
│   │   ├── testing
│   │   │   ├── calib & velodyne & image_2
├── pcdet
├── tools
```
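A quick sanity check of the KITTI layout, as a minimal sketch: the helper name is hypothetical, it assumes the default `data/kitti` location, and it only checks that the directories exist, not their contents.

```python
from pathlib import Path
from typing import List

# Sub-directories from the KITTI tree, relative to data/kitti.
REQUIRED = [
    "ImageSets",
    "training/calib", "training/velodyne", "training/label_2", "training/image_2",
    "testing/calib", "testing/velodyne", "testing/image_2",
]
OPTIONAL = ["training/planes"]   # road planes, only needed for that augmentation


def missing_kitti_dirs(root: str = "data/kitti", with_planes: bool = False) -> List[str]:
    """Return the expected sub-directories that do not exist under root."""
    expected = REQUIRED + (OPTIONAL if with_planes else [])
    base = Path(root)
    return [rel for rel in expected if not (base / rel).is_dir()]


if __name__ == "__main__":
    missing = missing_kitti_dirs()
    print("missing:", ", ".join(missing) if missing else "none")
```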
Run the following command to create the dataset infos:
```shell
python3 -m pcdet.datasets.kitti.kitti_dataset create_kitti_infos tools/cfgs/dataset_configs/kitti_dataset.yaml
```
- For the Waymo Open Dataset, organize the files as follows:
```
CasA
├── data
│   ├── waymo
│   │   ├── ImageSets
│   │   ├── raw_data
│   │   │   ├── segment-xxxxxxxx.tfrecord
│   │   │   ├── ...
│   │   ├── waymo_processed_data_train_val_test
│   │   │   ├── segment-xxxxxxxx/
│   │   │   ├── ...
│   │   ├── pcdet_waymo_track_dbinfos_train_cp.pkl
│   │   ├── waymo_infos_test.pkl
│   │   ├── waymo_infos_train.pkl
│   │   ├── waymo_infos_val.pkl
├── pcdet
├── tools
```
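In the Waymo layout, each processed segment directory carries the name of its raw `.tfrecord` file, so missing preprocessing can be spotted by comparing the two. This sketch assumes that naming scheme and the default `data/waymo` location; the helper names are hypothetical.

```python
from pathlib import Path
from typing import Iterable, List


def unprocessed_segments(tfrecords: Iterable[str], processed_dirs: Iterable[str]) -> List[str]:
    """Segment names that have a raw .tfrecord but no processed directory."""
    done = set(processed_dirs)
    return sorted(Path(f).stem for f in tfrecords if Path(f).stem not in done)


def scan(waymo_root: str = "data/waymo") -> List[str]:
    root = Path(waymo_root)
    raw = root / "raw_data"
    processed = root / "waymo_processed_data_train_val_test"
    if not raw.is_dir():
        raise FileNotFoundError(f"{raw} does not exist")
    records = [p.name for p in raw.glob("segment-*.tfrecord")]
    dirs = [p.name for p in processed.iterdir() if p.is_dir()] if processed.is_dir() else []
    return unprocessed_segments(records, dirs)


if __name__ == "__main__":
    if Path("data/waymo/raw_data").is_dir():
        print("segments without processed data:", scan())
```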
Run the following command to create the dataset infos:
```shell
python3 -m pcdet.datasets.waymo.waymo_tracking_dataset --cfg_file tools/cfgs/dataset_configs/waymo_tracking_dataset.yaml
```
Clone the repository and install it in development mode:
```shell
git clone https://github.com/hailanyi/CasA.git
cd CasA
python3 setup.py develop
```
Single-GPU test:
```shell
cd tools
python3 test.py --cfg_file ${CONFIG_FILE} --batch_size ${BATCH_SIZE} --ckpt ${CKPT}
```
For example, to test the CasA-V model:
```shell
cd tools
python3 test.py --cfg_file cfgs/kitti_models/CasA-V.yaml --ckpt CasA-V.pth
```
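When evaluating several checkpoints, it can help to build the `test.py` invocation programmatically. This sketch only uses the `--cfg_file`, `--ckpt` and `--batch_size` flags from the commands above; `build_test_cmd` is a hypothetical helper and the second checkpoint path is a placeholder.

```python
import shlex
from typing import List, Optional


def build_test_cmd(cfg_file: str, ckpt: str, batch_size: Optional[int] = None) -> List[str]:
    """Argument list for tools/test.py using only the documented flags."""
    cmd = ["python3", "test.py", "--cfg_file", cfg_file, "--ckpt", ckpt]
    if batch_size is not None:
        cmd += ["--batch_size", str(batch_size)]
    return cmd


if __name__ == "__main__":
    # Placeholder checkpoint paths; replace with your own files.
    for ckpt in ["CasA-V.pth", "my_run/checkpoint_a.pth"]:
        print(shlex.join(build_test_cmd("cfgs/kitti_models/CasA-V.yaml", ckpt, batch_size=4)))
```

To execute a command rather than print it, pass the list to `subprocess.run(cmd, cwd="tools", check=True)`.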
Multi-GPU test: set the number of GPUs in `dist_test.sh`, then run:
```shell
sh dist_test.sh
```
The logs are saved to `log-test.txt`; run `cat log-test.txt` to view the test results.
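To pull the 3D AP numbers out of `log-test.txt`, a small regex-based parser can help. It assumes the log contains OpenPCDet-style KITTI evaluation blocks (a header such as `Car AP_R40@0.70, 0.70, 0.70:` followed by a `3d   AP:easy, moderate, hard` line); if this repo's output differs, adjust the two patterns. `parse_3d_ap` is a hypothetical helper.

```python
import re
from pathlib import Path
from typing import Dict, Tuple

# Assumed line formats, e.g.
#   "Car AP@0.70, 0.70, 0.70:"  or  "Car AP_R40@0.70, 0.70, 0.70:"
#   "3d   AP:<easy>, <moderate>, <hard>"
HEADER = re.compile(r"(\w+) (AP|AP_R40)@([\d.]+), ([\d.]+), ([\d.]+):")
AP_3D = re.compile(r"3d\s+AP:([\d.]+), ([\d.]+), ([\d.]+)")

Key = Tuple[str, str, Tuple[float, float, float]]


def parse_3d_ap(text: str) -> Dict[Key, Tuple[float, float, float]]:
    """Map (class, 'AP' or 'AP_R40', IoU thresholds) to (easy, moderate, hard) 3D AP."""
    results: Dict[Key, Tuple[float, float, float]] = {}
    current = None
    for line in text.splitlines():
        header = HEADER.search(line)
        if header:
            overlaps = tuple(float(x) for x in header.group(3, 4, 5))
            current = (header.group(1), header.group(2), overlaps)
            continue
        ap = AP_3D.search(line)
        if ap and current is not None:
            # If the log holds several evaluations, later ones overwrite earlier ones.
            results[current] = tuple(float(x) for x in ap.groups())
    return results


if __name__ == "__main__":
    log = Path("log-test.txt")
    if log.exists():
        for (cls, metric, iou), (easy, mod, hard) in parse_3d_ap(log.read_text()).items():
            print(f"{cls:10} {metric:6} IoU={iou[0]:.2f}  moderate 3D AP = {mod:.2f}")
```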
Single-GPU train:
```shell
cd tools
python3 train.py --cfg_file ${CONFIG_FILE}
```
For example, to train the CasA-V model:
```shell
cd tools
python3 train.py --cfg_file cfgs/kitti_models/CasA-V.yaml
```
Multi-GPU train: set the number of GPUs in `dist_train.sh`, then run:
```shell
sh dist_train.sh
```
The logs are saved to `log.txt`; run `cat log.txt` to view the training progress.
This repo is developed from OpenPCDet 0.3. We thank Shaoshuai Shi for his implementation of OpenPCDet.
If you find this project useful in your research, please consider citing:
```bibtex
@article{casa2022,
  title={CasA: A Cascade Attention Network for 3D Object Detection from LiDAR point clouds},
  author={Wu, Hai and Deng, Jinhao and Wen, Chenglu and Li, Xin and Wang, Cheng},
  journal={IEEE Transactions on Geoscience and Remote Sensing},
  year={2022}
}
```