[Feature] Support training of BEVFusion (#2558)
* support train on nus

* refactor transfusion head

* img branch optional

* support nuscenes_mini in replace_ceph_backend

* use replace_ceph

* add only-lidar

* use valid_flag in dataset filter

* support lidar-only training 69

* fix RTS

* fix rotation in ImgAug3D

* revert to original rotation in ImgAug3D

* add LSSDepthTransform and parse_losses

* fix LoadMultiSweeps

* fix bug about points in-place operations

* support amp and replace syncBN by BN

* add amp config

* set growth-interval in amp

* Revert "fix LoadMultiSweeps"

This reverts commit ab27ea1.

* add float in cls loss

* iter_based lr in fusion stage

* rename config

* use normalization query pos for stable training

* remove unnecessary code & simplify config & train 5 epoch

* smaller ete_min_ratio

* polish code

* fix UT

* Revert "use normalization query pos for stable training"

This reverts commit 3009118.

* update readme

* fix height offset
JingweiZhang12 authored May 31, 2023
1 parent ed46b8c commit 583e907
Showing 11 changed files with 693 additions and 249 deletions.
7 changes: 6 additions & 1 deletion mmdet3d/engine/hooks/disable_object_sample_hook.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dataset import BaseDataset
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.runner import Runner
@@ -35,7 +36,11 @@ def before_train_epoch(self, runner: Runner):
model = model.module
if epoch == self.disable_after_epoch:
runner.logger.info('Disable ObjectSample')
for transform in runner.train_dataloader.dataset.pipeline.transforms: # noqa: E501
dataset = runner.train_dataloader.dataset
# handle dataset wrapper
if not isinstance(dataset, BaseDataset):
dataset = dataset.dataset
for transform in dataset.pipeline.transforms: # noqa: E501
if isinstance(transform, ObjectSample):
assert hasattr(transform, 'disabled')
transform.disabled = True
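For reference, a minimal sketch of how this hook is typically enabled in a training config (the hook name comes from this file; the epoch value is illustrative, not taken from this commit):

```python
custom_hooks = [
    # Turn off the ObjectSample (GT-paste) augmentation after the given epoch.
    dict(type='DisableObjectSampleHook', disable_after_epoch=15)
]
```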
27 changes: 15 additions & 12 deletions projects/BEVFusion/README.md
@@ -15,7 +15,7 @@ results is available at https://github.com/mit-han-lab/bevfusion.

## Introduction

We implement BEVFusion and provide the results and pretrained checkpoints on NuScenes dataset.
We implement BEVFusion and support training and testing on NuScenes dataset.

## Usage

@@ -34,38 +34,41 @@ python projects/BEVFusion/setup.py develop
Run a demo on NuScenes data using [BEVFusion model](https://drive.google.com/file/d/1QkvbYDk4G2d6SZoeJqish13qSyXA4lp3/view?usp=share_link):

```shell
python demo/multi_modality_demo.py demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__LIDAR_TOP__1532402927647951.pcd.bin demo/data/nuscenes/ demo/data/nuscenes/n015-2018-07-24-11-22-45+0800.pkl projects/BEVFusion/configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py ${CHECKPOINT_FILE} --cam-type all --score-thr 0.2 --show
python demo/multi_modality_demo.py demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__LIDAR_TOP__1532402927647951.pcd.bin demo/data/nuscenes/ demo/data/nuscenes/n015-2018-07-24-11-22-45+0800.pkl projects/BEVFusion/configs/bevfusion_lidar-cam_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py ${CHECKPOINT_FILE} --cam-type all --score-thr 0.2 --show
```

### Training commands

In MMDetection3D's root directory, run the following command to train the model:
1. You should train the lidar-only detector first:

```bash
python tools/train.py projects/BEVFusion/configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py
bash tools/dist_train.sh projects/BEVFusion/configs/bevfusion_lidar_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py 8
```

For multi-gpu training, run:
2. Download the [Swin pre-trained model](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/bevfusion/swint-nuimages-pretrained.pth). Given the pre-trained image backbone and the pre-trained lidar-only detector, you can train the lidar-camera fusion model:

```bash
python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=${NUM_GPUS} --master_port=29506 --master_addr="127.0.0.1" tools/train.py projects/BEVFusion/configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py
bash tools/dist_train.sh projects/BEVFusion/configs/bevfusion_lidar-cam_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py 8 --cfg-options load_from=${LIDAR_PRETRAINED_CHECKPOINT} model.img_backbone.init_cfg.checkpoint=${IMAGE_PRETRAINED_BACKBONE}
```

**Note** that if you want to reduce CUDA memory usage and computational overhead, you can simply append `--amp` to the above commands. The model will then be trained in fp16 mode.
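For example, AMP (fp16) training of the lidar-only detector would look like the following (a sketch; it simply appends `--amp` to the command from step 1):

```bash
bash tools/dist_train.sh projects/BEVFusion/configs/bevfusion_lidar_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py 8 --amp
```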

### Testing commands

In MMDetection3D's root directory, run the following command to test the model:

```bash
python tools/test.py projects/BEVFusion/configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py ${CHECKPOINT_PATH}
bash tools/dist_test.sh projects/BEVFusion/configs/bevfusion_lidar-cam_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py ${CHECKPOINT_PATH} 8
```

## Results and models

### NuScenes

| Backbone | Voxel type (voxel size) | NMS | Mem (GB) | Inf time (fps) | NDS | mAP | Download |
| :-----------------------------------------------------------------------------: | :---------------------: | :-: | :------: | :------------: | :---: | :---: | :------------------------------------------------------------------------------------------------------: |
| [SECFPN](./configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py) | voxel (0.075) | × | - | - | 71.62 | 68.77 | [converted_model](https://drive.google.com/file/d/1QkvbYDk4G2d6SZoeJqish13qSyXA4lp3/view?usp=share_link) |
| Modality | Voxel type (voxel size) | NMS | Mem (GB) | Inf time (fps) | NDS | mAP | Download |
| :------------------------------------------------------------------------------------------: | :---------------------: | :-: | :------: | :------------: | :--: | :--: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
| [lidar](./configs/bevfusion_lidar_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py) | voxel (0.075) | × | - | - | 69.6 | 64.9 | [model](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/bevfusion/bevfusion_lidar_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d-2628f933.pth) [logs](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/bevfusion/bevfusion_lidar_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d_20230322_053447.log) |
| [lidar-cam](./configs/bevfusion_lidar-cam_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py) | voxel (0.075) | × | - | - | 71.4 | 68.6 | [model](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/bevfusion/bevfusion_lidar-cam_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d-5239b1af.pth) [logs](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/bevfusion/bevfusion_lidar-cam_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d_20230524_001539.log) |

## Citation

@@ -103,9 +106,9 @@ A project does not necessarily have to be finished in a single PR, but it's esse

<!-- As this template does. -->

- [ ] Milestone 2: Indicates a successful model implementation.
- [x] Milestone 2: Indicates a successful model implementation.

- [ ] Training-time correctness
- [x] Training-time correctness

<!-- If you are reproducing the result from a paper, checking this item means that you should have trained your model from scratch based on the original paper's specification and verified that the final result matches the report within a minor error range. -->

10 changes: 6 additions & 4 deletions projects/BEVFusion/bevfusion/__init__.py
@@ -1,18 +1,20 @@
from .bevfusion import BEVFusion
from .bevfusion_necks import GeneralizedLSSFPN
from .depth_lss import DepthLSSTransform
from .depth_lss import DepthLSSTransform, LSSTransform
from .loading import BEVLoadMultiViewImageFromFiles
from .sparse_encoder import BEVFusionSparseEncoder
from .transformer import TransformerDecoderLayer
from .transforms_3d import GridMask, ImageAug3D
from .transforms_3d import (BEVFusionGlobalRotScaleTrans,
BEVFusionRandomFlip3D, GridMask, ImageAug3D)
from .transfusion_head import ConvFuser, TransFusionHead
from .utils import (BBoxBEVL1Cost, HeuristicAssigner3D, HungarianAssigner3D,
IoU3DCost)

__all__ = [
'BEVFusion', 'TransFusionHead', 'ConvFuser', 'ImageAug3D', 'GridMask',
'GeneralizedLSSFPN', 'HungarianAssigner3D', 'BBoxBEVL1Cost', 'IoU3DCost',
'HeuristicAssigner3D', 'DepthLSSTransform',
'HeuristicAssigner3D', 'DepthLSSTransform', 'LSSTransform',
'BEVLoadMultiViewImageFromFiles', 'BEVFusionSparseEncoder',
'TransformerDecoderLayer'
'TransformerDecoderLayer', 'BEVFusionRandomFlip3D',
'BEVFusionGlobalRotScaleTrans'
]
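For illustration, a hedged sketch of how the newly exported augmentation transforms might be used in a train pipeline config (argument names assumed by analogy with GlobalRotScaleTrans and RandomFlip3D; values are illustrative, not taken from this commit):

```python
train_pipeline = [
    # ... loading and point/image preprocessing transforms ...
    dict(
        type='BEVFusionGlobalRotScaleTrans',
        rot_range=[-0.78539816, 0.78539816],
        scale_ratio_range=[0.9, 1.1],
        translation_std=0.5),
    dict(type='BEVFusionRandomFlip3D'),
    # ... formatting / bundling transforms ...
]
```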
154 changes: 105 additions & 49 deletions projects/BEVFusion/bevfusion/bevfusion.py
@@ -1,7 +1,11 @@
from typing import Dict, List, Optional
from collections import OrderedDict
from copy import deepcopy
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.distributed as dist
from mmengine.utils import is_list_of
from torch import Tensor
from torch.nn import functional as F

@@ -23,7 +27,7 @@ def __init__(
fusion_layer: Optional[dict] = None,
img_backbone: Optional[dict] = None,
pts_backbone: Optional[dict] = None,
vtransform: Optional[dict] = None,
view_transform: Optional[dict] = None,
img_neck: Optional[dict] = None,
pts_neck: Optional[dict] = None,
bbox_head: Optional[dict] = None,
@@ -40,20 +44,21 @@

self.pts_voxel_encoder = MODELS.build(pts_voxel_encoder)

self.img_backbone = MODELS.build(img_backbone)
self.img_neck = MODELS.build(img_neck)
self.vtransform = MODELS.build(vtransform)
self.img_backbone = MODELS.build(
img_backbone) if img_backbone is not None else None
self.img_neck = MODELS.build(
img_neck) if img_neck is not None else None
self.view_transform = MODELS.build(
view_transform) if view_transform is not None else None
self.pts_middle_encoder = MODELS.build(pts_middle_encoder)

self.fusion_layer = MODELS.build(fusion_layer)
self.fusion_layer = MODELS.build(
fusion_layer) if fusion_layer is not None else None

self.pts_backbone = MODELS.build(pts_backbone)
self.pts_neck = MODELS.build(pts_neck)

self.bbox_head = MODELS.build(bbox_head)
# hard code here where using converted checkpoint of original
# implementation of `BEVFusion`
self.use_converted_checkpoint = True

self.init_weights()

@@ -67,6 +72,46 @@ def _forward(self,
"""
pass

def parse_losses(
self, losses: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Parses the raw outputs (losses) of the network.
Args:
losses (dict): Raw output of the network, which usually contain
losses and other necessary information.
Returns:
tuple[Tensor, dict]: There are two elements. The first is the
loss tensor passed to optim_wrapper which may be a weighted sum
of all losses, and the second is log_vars which will be sent to
the logger.
"""
log_vars = []
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars.append([loss_name, loss_value.mean()])
elif is_list_of(loss_value, torch.Tensor):
log_vars.append(
[loss_name,
sum(_loss.mean() for _loss in loss_value)])
else:
raise TypeError(
f'{loss_name} is not a tensor or list of tensors')

loss = sum(value for key, value in log_vars if 'loss' in key)
log_vars.insert(0, ['loss', loss])
log_vars = OrderedDict(log_vars) # type: ignore

for loss_name, loss_value in log_vars.items():
# reduce loss when distributed training
if dist.is_available() and dist.is_initialized():
loss_value = loss_value.data.clone()
dist.all_reduce(loss_value.div_(dist.get_world_size()))
log_vars[loss_name] = loss_value.item()

return loss, log_vars # type: ignore
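As a quick illustration of what `parse_losses` produces, a hedged sketch (assumes `model` is a built BEVFusion instance; values are illustrative):

```python
import torch

# Hypothetical raw output of the bbox head: scalar tensors and lists of tensors.
raw_losses = dict(
    loss_heatmap=torch.tensor(0.5),
    loss_bbox=[torch.tensor(0.2), torch.tensor(0.3)],  # list entries are summed per key
)
loss, log_vars = model.parse_losses(raw_losses)
# loss     -> tensor(1.0), the sum of every entry whose key contains 'loss'
# log_vars -> OrderedDict([('loss', 1.0), ('loss_heatmap', 0.5), ('loss_bbox', 0.5)])
```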

def init_weights(self) -> None:
if self.img_backbone is not None:
self.img_backbone.init_weights()
@@ -94,7 +139,7 @@ def extract_img_feat(
img_metas,
) -> torch.Tensor:
B, N, C, H, W = x.size()
x = x.view(B * N, C, H, W)
x = x.view(B * N, C, H, W).contiguous()

x = self.img_backbone(x)
x = self.img_neck(x)
@@ -105,22 +150,25 @@
BN, C, H, W = x.size()
x = x.view(B, int(BN / B), C, H, W)

x = self.vtransform(
x,
points,
lidar2image,
camera_intrinsics,
camera2lidar,
img_aug_matrix,
lidar_aug_matrix,
img_metas,
)
with torch.autocast(device_type='cuda', dtype=torch.float32):
x = self.view_transform(
x,
points,
lidar2image,
camera_intrinsics,
camera2lidar,
img_aug_matrix,
lidar_aug_matrix,
img_metas,
)
return x

def extract_pts_feat(self, batch_inputs_dict) -> torch.Tensor:
points = batch_inputs_dict['points']
feats, coords, sizes = self.voxelize(points)
batch_size = coords[-1, 0] + 1
with torch.autocast('cuda', enabled=False):
points = [point.float() for point in points]
feats, coords, sizes = self.voxelize(points)
batch_size = coords[-1, 0] + 1
x = self.pts_middle_encoder(feats, coords, batch_size)
return x

@@ -184,11 +232,6 @@ def predict(self, batch_inputs_dict: Dict[str, Optional[Tensor]],

if self.with_bbox_head:
outputs = self.bbox_head.predict(feats, batch_input_metas)
if self.use_converted_checkpoint:
outputs[0]['bboxes_3d'].tensor[:, 6] = -outputs[0][
'bboxes_3d'].tensor[:, 6] - np.pi / 2
outputs[0]['bboxes_3d'].tensor[:, 3:5] = outputs[0][
'bboxes_3d'].tensor[:, [4, 3]]

res = self.add_pred_to_datasample(batch_data_samples, outputs)

@@ -202,28 +245,32 @@ def extract_feat(
):
imgs = batch_inputs_dict.get('imgs', None)
points = batch_inputs_dict.get('points', None)

lidar2image, camera_intrinsics, camera2lidar = [], [], []
img_aug_matrix, lidar_aug_matrix = [], []
for i, meta in enumerate(batch_input_metas):
lidar2image.append(meta['lidar2img'])
camera_intrinsics.append(meta['cam2img'])
camera2lidar.append(meta['cam2lidar'])
img_aug_matrix.append(meta.get('img_aug_matrix', np.eye(4)))
lidar_aug_matrix.append(meta.get('lidar_aug_matrix', np.eye(4)))

lidar2image = imgs.new_tensor(np.asarray(lidar2image))
camera_intrinsics = imgs.new_tensor(np.array(camera_intrinsics))
camera2lidar = imgs.new_tensor(np.asarray(camera2lidar))
img_aug_matrix = imgs.new_tensor(np.asarray(img_aug_matrix))
lidar_aug_matrix = imgs.new_tensor(np.asarray(lidar_aug_matrix))
img_feature = self.extract_img_feat(imgs, points, lidar2image,
camera_intrinsics, camera2lidar,
img_aug_matrix, lidar_aug_matrix,
batch_input_metas)
features = []
if imgs is not None:
imgs = imgs.contiguous()
lidar2image, camera_intrinsics, camera2lidar = [], [], []
img_aug_matrix, lidar_aug_matrix = [], []
for i, meta in enumerate(batch_input_metas):
lidar2image.append(meta['lidar2img'])
camera_intrinsics.append(meta['cam2img'])
camera2lidar.append(meta['cam2lidar'])
img_aug_matrix.append(meta.get('img_aug_matrix', np.eye(4)))
lidar_aug_matrix.append(
meta.get('lidar_aug_matrix', np.eye(4)))

lidar2image = imgs.new_tensor(np.asarray(lidar2image))
camera_intrinsics = imgs.new_tensor(np.array(camera_intrinsics))
camera2lidar = imgs.new_tensor(np.asarray(camera2lidar))
img_aug_matrix = imgs.new_tensor(np.asarray(img_aug_matrix))
lidar_aug_matrix = imgs.new_tensor(np.asarray(lidar_aug_matrix))
img_feature = self.extract_img_feat(imgs, deepcopy(points),
lidar2image, camera_intrinsics,
camera2lidar, img_aug_matrix,
lidar_aug_matrix,
batch_input_metas)
features.append(img_feature)
pts_feature = self.extract_pts_feat(batch_inputs_dict)

features = [img_feature, pts_feature]
features.append(pts_feature)

if self.fusion_layer is not None:
x = self.fusion_layer(features)
Expand All @@ -239,4 +286,13 @@ def extract_feat(
def loss(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
batch_data_samples: List[Det3DDataSample],
**kwargs) -> List[Det3DDataSample]:
pass
batch_input_metas = [item.metainfo for item in batch_data_samples]
feats = self.extract_feat(batch_inputs_dict, batch_input_metas)

losses = dict()
if self.with_bbox_head:
bbox_loss = self.bbox_head.loss(feats, batch_data_samples)

losses.update(bbox_loss)

return losses
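To put `loss` and `parse_losses` in context, a simplified sketch of how one training iteration typically flows through them under mmengine-style training (names such as `optim_wrapper` follow mmengine conventions and are not part of this commit):

```python
# Simplified sketch of a single training step (AMP and gradient clipping omitted).
data = model.data_preprocessor(data, training=True)        # batch, pad and move to device
losses = model.loss(data['inputs'], data['data_samples'])  # dict of raw loss tensors
loss, log_vars = model.parse_losses(losses)                 # scalar total + values for logging
optim_wrapper.update_params(loss)                           # backward + optimizer step
```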