Skip to content

Commit

Permalink
[Feature] Support DGCNN (v1.0.0.dev0) (#896)
Browse files Browse the repository at this point in the history
* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* support dgcnn

* fix typo

* fix typo

* fix typo

* del gf&fa registry (wo reuse pointnet module)

* fix typo

* add benchmark and add copyright header (for DGCNN only)

* fix typo

* fix typo

* fix typo

* fix typo

* fix typo

* support dgcnn
  • Loading branch information
DCNSW authored Sep 3, 2021
1 parent d4b1244 commit f095eb6
Show file tree
Hide file tree
Showing 23 changed files with 953 additions and 44 deletions.
1 change: 1 addition & 0 deletions .dev_scripts/gather_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
'_6x_': 73,
'_50e_': 50,
'_80e_': 80,
'_100e_': 100,
'_150e_': 150,
'_200e_': 200,
'_250e_': 250,
Expand Down
41 changes: 22 additions & 19 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ Support backbones:
- [x] PointNet (CVPR'2017)
- [x] PointNet++ (NeurIPS'2017)
- [x] RegNet (CVPR'2020)
- [x] DGCNN (TOG'2019)

Support methods

Expand All @@ -94,25 +95,27 @@ Support methods
- [x] [Group-Free-3D (Arxiv'2021)](configs/groupfree3d/README.md)
- [x] [ImVoxelNet (Arxiv'2021)](configs/imvoxelnet/README.md)
- [x] [PAConv (CVPR'2021)](configs/paconv/README.md)

| | ResNet | ResNeXt | SENet |PointNet++ | HRNet | RegNetX | Res2Net |
|--------------------|:--------:|:--------:|:--------:|:---------:|:-----:|:--------:|:-----:|
| SECOND ||||||||
| PointPillars ||||||||
| FreeAnchor ||||||||
| VoteNet ||||||||
| H3DNet ||||||||
| 3DSSD ||||||||
| Part-A2 ||||||||
| MVXNet ||||||||
| CenterPoint ||||||||
| SSN ||||||||
| ImVoteNet ||||||||
| FCOS3D ||||||||
| PointNet++ ||||||||
| Group-Free-3D ||||||||
| ImVoxelNet ||||||||
| PAConv ||||||||
- [x] [DGCNN (TOG'2019)](configs/dgcnn/README.md)

| | ResNet | ResNeXt | SENet |PointNet++ |DGCNN | HRNet | RegNetX | Res2Net |
|--------------------|:--------:|:--------:|:--------:|:---------:|:---------:|:-----:|:--------:|:-----:|
| SECOND |||||||||
| PointPillars |||||||||
| FreeAnchor |||||||||
| VoteNet |||||||||
| H3DNet |||||||||
| 3DSSD |||||||||
| Part-A2 |||||||||
| MVXNet |||||||||
| CenterPoint |||||||||
| SSN |||||||||
| ImVoteNet |||||||||
| FCOS3D |||||||||
| PointNet++ |||||||||
| Group-Free-3D |||||||||
| ImVoxelNet |||||||||
| PAConv |||||||||
| DGCNN |||||||||

Other features
- [x] [Dynamic Voxelization](configs/dynamic_voxelization/README.md)
Expand Down
41 changes: 22 additions & 19 deletions README_zh-CN.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ MMDetection3D 是一个基于 PyTorch 的目标检测开源工具箱, 下一代
- [x] PointNet (CVPR'2017)
- [x] PointNet++ (NeurIPS'2017)
- [x] RegNet (CVPR'2020)
- [x] DGCNN (TOG'2019)

已支持的算法:

Expand All @@ -93,25 +94,27 @@ MMDetection3D 是一个基于 PyTorch 的目标检测开源工具箱, 下一代
- [x] [Group-Free-3D (Arxiv'2021)](configs/groupfree3d/README.md)
- [x] [ImVoxelNet (Arxiv'2021)](configs/imvoxelnet/README.md)
- [x] [PAConv (CVPR'2021)](configs/paconv/README.md)

| | ResNet | ResNeXt | SENet |PointNet++ | HRNet | RegNetX | Res2Net |
|--------------------|:--------:|:--------:|:--------:|:---------:|:-----:|:--------:|:-----:|
| SECOND ||||||||
| PointPillars ||||||||
| FreeAnchor ||||||||
| VoteNet ||||||||
| H3DNet ||||||||
| 3DSSD ||||||||
| Part-A2 ||||||||
| MVXNet ||||||||
| CenterPoint ||||||||
| SSN ||||||||
| ImVoteNet ||||||||
| FCOS3D ||||||||
| PointNet++ ||||||||
| Group-Free-3D ||||||||
| ImVoxelNet ||||||||
| PAConv ||||||||
- [x] [DGCNN (TOG'2019)](configs/dgcnn/README.md)

| | ResNet | ResNeXt | SENet |PointNet++ |DGCNN | HRNet | RegNetX | Res2Net |
|--------------------|:--------:|:--------:|:--------:|:---------:|:---------:|:-----:|:--------:|:-----:|
| SECOND |||||||||
| PointPillars |||||||||
| FreeAnchor |||||||||
| VoteNet |||||||||
| H3DNet |||||||||
| 3DSSD |||||||||
| Part-A2 |||||||||
| MVXNet |||||||||
| CenterPoint |||||||||
| SSN |||||||||
| ImVoteNet |||||||||
| FCOS3D |||||||||
| PointNet++ |||||||||
| Group-Free-3D |||||||||
| ImVoxelNet |||||||||
| PAConv |||||||||
| DGCNN |||||||||

其他特性
- [x] [Dynamic Voxelization](configs/dynamic_voxelization/README.md)
Expand Down
28 changes: 28 additions & 0 deletions configs/_base_/models/dgcnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# model settings
model = dict(
type='EncoderDecoder3D',
backbone=dict(
type='DGCNNBackbone',
in_channels=9, # [xyz, rgb, normal_xyz], modified with dataset
num_samples=(20, 20, 20),
knn_modes=('D-KNN', 'F-KNN', 'F-KNN'),
radius=(None, None, None),
gf_channels=((64, 64), (64, 64), (64, )),
fa_channels=(1024, ),
act_cfg=dict(type='LeakyReLU', negative_slope=0.2)),
decode_head=dict(
type='DGCNNHead',
fp_channels=(1216, 512),
channels=256,
dropout_ratio=0.5,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
act_cfg=dict(type='LeakyReLU', negative_slope=0.2),
loss_decode=dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=None, # modified with dataset
loss_weight=1.0)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='slide'))
8 changes: 8 additions & 0 deletions configs/_base_/schedules/seg_cosine_100e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# optimizer
# This schedule is mainly used on S3DIS dataset in segmentation task
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
lr_config = dict(policy='CosineAnnealing', warmup=None, min_lr=1e-5)

# runtime settings
runner = dict(type='EpochBasedRunner', max_epochs=100)
43 changes: 43 additions & 0 deletions configs/dgcnn/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Dynamic Graph CNN for Learning on Point Clouds

## Introduction

<!-- [ALGORITHM] -->

We implement DGCNN and provide the results and checkpoints on S3DIS dataset.

```
@article{dgcnn,
title={Dynamic Graph CNN for Learning on Point Clouds},
author={Wang, Yue and Sun, Yongbin and Liu, Ziwei and Sarma, Sanjay E. and Bronstein, Michael M. and Solomon, Justin M.},
journal={ACM Transactions on Graphics (TOG)},
year={2019}
}
```

**Notice**: We follow the implementations in the original DGCNN paper and a PyTorch implementation of DGCNN [code](https://github.com/AnTao97/dgcnn.pytorch).

## Results

### S3DIS

| Method | Split | Lr schd | Mem (GB) | Inf time (fps) | mIoU (Val set) | Download |
| :-------------------------------------------------------------------------: | :----: | :--------: | :------: | :------------: | :------------: | :----------------------: |
| [DGCNN](./dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py) | Area_1 | cosine 100e | 13.1 | | 68.33 | [model](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area1/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210731_000734-39658f14.pth) &#124; [log](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area1/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210731_000734.log.json) |
| [DGCNN](./dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py) | Area_2 | cosine 100e | 13.1 | | 40.68 | [model](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area2/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210731_144648-aea9ecb6.pth) &#124; [log](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area2/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210731_144648.log.json) |
| [DGCNN](./dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py) | Area_3 | cosine 100e | 13.1 | | 69.38 | [model](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area3/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210801_154629-2ff50ee0.pth) &#124; [log](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area3/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210801_154629.log.json) |
| [DGCNN](./dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py) | Area_4 | cosine 100e | 13.1 | | 50.07 | [model](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area4/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210802_073551-dffab9cd.pth) &#124; [log](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area4/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210802_073551.log.json) |
| [DGCNN](./dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py) | Area_5 | cosine 100e | 13.1 | | 50.59 | [model](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area5/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210730_235824-f277e0c5.pth) &#124; [log](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area5/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210730_235824.log.json) |
| [DGCNN](./dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py) | Area_6 | cosine 100e | 13.1 | | 77.94 | [model](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area6/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210802_154317-e3511b32.pth) &#124; [log](https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area6/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210802_154317.log.json) |
| [DGCNN](./dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py) | 6-fold | | | | 59.43 | |

**Notes:**

- We use XYZ+Color+Normalized_XYZ as input in all the experiments on S3DIS datasets.
- `Area_5` Split means training the model on Area_1, 2, 3, 4, 6 and testing on Area_5.
- `6-fold` Split means the overall result of 6 different splits (Area_1, Area_2, Area_3, Area_4, Area_5 and Area_6 Splits).
- Users need to modify `train_area` and `test_area` in the S3DIS dataset's [config](./configs/_base_/datasets/s3dis_seg-3d-13class.py) to set the training and testing areas, respectively.

## Indeterminism

Since DGCNN testing adopts sliding patch inference which involves random point sampling, and the test script uses fixed random seeds while the random seeds of validation in training are not fixed, the test results may be slightly different from the results reported above.
24 changes: 24 additions & 0 deletions configs/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
_base_ = [
'../_base_/datasets/s3dis_seg-3d-13class.py', '../_base_/models/dgcnn.py',
'../_base_/schedules/seg_cosine_100e.py', '../_base_/default_runtime.py'
]

# data settings
data = dict(samples_per_gpu=32)
evaluation = dict(interval=2)

# model settings
model = dict(
backbone=dict(in_channels=9), # [xyz, rgb, normalized_xyz]
decode_head=dict(
num_classes=13, ignore_index=13,
loss_decode=dict(class_weight=None)), # S3DIS doesn't use class_weight
test_cfg=dict(
num_points=4096,
block_size=1.0,
sample_rate=0.5,
use_normalized_coord=True,
batch_size=24))

# runtime settings
checkpoint_config = dict(interval=2)
24 changes: 24 additions & 0 deletions configs/dgcnn/metafile.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
Collections:
- Name: DGCNN
Metadata:
Training Techniques:
- SGD
Training Resources: 4x Titan XP GPUs
Architecture:
- DGCNN
Paper: https://arxiv.org/abs/1801.07829
README: configs/dgcnn/README.md

Models:
- Name: dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py
In Collection: DGCNN
Config: configs/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class.py
Metadata:
Training Data: S3DIS
Training Memory (GB): 13.3
Results:
- Task: 3D Semantic Segmentation
Dataset: S3DIS
Metrics:
mIoU: 50.59
Weights: https://download.openmmlab.com/mmdetection3d/v0.17.0_models/dgcnn/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class/area5/dgcnn_32x4_cosine_100e_s3dis_seg-3d-13class_20210730_235824-f277e0c5.pth
4 changes: 4 additions & 0 deletions docs/model_zoo.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,7 @@ Please refer to [ImVoxelNet](https://github.com/open-mmlab/mmdetection3d/blob/ma
### PAConv

Please refer to [PAConv](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/paconv) for details. We provide PAConv baselines on S3DIS dataset.

### DGCNN

Please refer to [DGCNN](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/dgcnn) for details. We provide DGCNN baselines on S3DIS dataset.
8 changes: 8 additions & 0 deletions docs_zh-CN/model_zoo.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,11 @@
### ImVoxelNet

请参考 [ImVoxelNet](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/imvoxelnet) 获取更多细节,我们在 KITTI 数据集上给出了相应的结果。

### PAConv

请参考 [PAConv](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/paconv) 获取更多细节,我们在 S3DIS 数据集上给出了相应的结果.

### DGCNN

请参考 [DGCNN](https://github.com/open-mmlab/mmdetection3d/blob/master/configs/dgcnn) 获取更多细节,我们在 S3DIS 数据集上给出了相应的结果.
4 changes: 3 additions & 1 deletion mmdet3d/models/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.models.backbones import SSDVGG, HRNet, ResNet, ResNetV1d, ResNeXt
from .dgcnn import DGCNNBackbone
from .multi_backbone import MultiBackbone
from .nostem_regnet import NoStemRegNet
from .pointnet2_sa_msg import PointNet2SAMSG
Expand All @@ -8,5 +9,6 @@

__all__ = [
'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet',
'SECOND', 'PointNet2SASSG', 'PointNet2SAMSG', 'MultiBackbone'
'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG',
'MultiBackbone'
]
Loading

0 comments on commit f095eb6

Please sign in to comment.