[Feature] Support training of BEVFusion #2558

Merged

Commits (32)
- 0784667 support train on nus (JingweiZhang12, Mar 17, 2023)
- 2dbf063 refactor transfusion head (JingweiZhang12, Mar 20, 2023)
- 149150d img branch optioinal (JingweiZhang12, Mar 20, 2023)
- 1a9c681 support nuscenes_mini in replace_ceph_backend (JingweiZhang12, Mar 20, 2023)
- eb8c69d use replace_ceph (JingweiZhang12, Mar 20, 2023)
- 4223023 add only-lidar (JingweiZhang12, Mar 20, 2023)
- 5c5d555 use valid_flag in dataset filter (JingweiZhang12, Mar 21, 2023)
- 8078f57 support lidar-only training 69 (JingweiZhang12, Mar 24, 2023)
- 68e4f31 fix RTS (JingweiZhang12, Apr 18, 2023)
- cf39e07 fix rotation in ImgAug3D (JingweiZhang12, Apr 24, 2023)
- 710a23d revert to original rotation in ImgAug3D (JingweiZhang12, May 9, 2023)
- 041b288 add LSSDepthTransform and parse_losses (JingweiZhang12, May 9, 2023)
- ab27ea1 fix LoadMultiSweeps (JingweiZhang12, May 16, 2023)
- 40a97e2 fix bug about points in-place operations (JingweiZhang12, May 16, 2023)
- f17d03b support amp and replace syncBN by BN (JingweiZhang12, May 17, 2023)
- dc4b7be add amp config (JingweiZhang12, May 18, 2023)
- a059517 set growth-interval in amp (JingweiZhang12, May 19, 2023)
- 9457729 Revert "fix LoadMultiSweeps" (JingweiZhang12, May 20, 2023)
- 836c775 add float in cls loss (JingweiZhang12, May 23, 2023)
- 839f05e iter_based lr in fusion stage (JingweiZhang12, May 24, 2023)
- 0b32c0a rename config (JingweiZhang12, May 26, 2023)
- 3009118 use normalization query pos for stable training (JingweiZhang12, May 26, 2023)
- 5fee2e0 remove unnecessary code & simplify config & train 5 epoch (JingweiZhang12, May 26, 2023)
- e656533 Merge branch 'dev-1.x' of https://github.com/open-mmlab/mmdetection3d… (JingweiZhang12, May 26, 2023)
- 51f5985 smaller ete_min_ratio (JingweiZhang12, May 29, 2023)
- f22c0c3 Merge branch 'dev-1.x' of github.com:open-mmlab/mmdetection3d into be… (JingweiZhang12, May 29, 2023)
- 09311fe polish code (JingweiZhang12, May 29, 2023)
- 9ac3e03 fix UT (JingweiZhang12, May 30, 2023)
- 411105b Revert "use normalization query pos for stable training" (JingweiZhang12, May 30, 2023)
- e7c12e4 Merge branch 'bevfusion_oldsweep_fixpos' of https://github.com/Jingwe… (JingweiZhang12, May 30, 2023)
- 3678fb4 update readme (JingweiZhang12, May 31, 2023)
- e127201 fix height offset (JingweiZhang12, May 31, 2023)
7 changes: 6 additions & 1 deletion mmdet3d/engine/hooks/disable_object_sample_hook.py
@@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine.dataset import BaseDataset
from mmengine.hooks import Hook
from mmengine.model import is_model_wrapper
from mmengine.runner import Runner
@@ -35,7 +36,11 @@ def before_train_epoch(self, runner: Runner):
model = model.module
if epoch == self.disable_after_epoch:
runner.logger.info('Disable ObjectSample')
for transform in runner.train_dataloader.dataset.pipeline.transforms: # noqa: E501
dataset = runner.train_dataloader.dataset
# handle dataset wrapper
if not isinstance(dataset, BaseDataset):
dataset = dataset.dataset
for transform in dataset.pipeline.transforms: # noqa: E501
if isinstance(transform, ObjectSample):
assert hasattr(transform, 'disabled')
transform.disabled = True
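For context, this hook is enabled from a training config through `custom_hooks`; a minimal sketch (the `disable_after_epoch` value below is illustrative, not necessarily what the BEVFusion config uses):

```python
# Sketch: turn off GT-paste (ObjectSample) augmentation late in training.
# The epoch threshold is illustrative.
custom_hooks = [
    dict(type='DisableObjectSampleHook', disable_after_epoch=15)
]
```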
18 changes: 9 additions & 9 deletions projects/BEVFusion/README.md
@@ -15,7 +15,7 @@ results is available at https://github.com/mit-han-lab/bevfusion.

## Introduction

We implement BEVFusion and provide the results and pretrained checkpoints on NuScenes dataset.
We implement BEVFusion and support training and testing on the NuScenes dataset.

## Usage

@@ -34,29 +34,29 @@ python projects/BEVFusion/setup.py develop
Run a demo on NuScenes data using [BEVFusion model](https://drive.google.com/file/d/1QkvbYDk4G2d6SZoeJqish13qSyXA4lp3/view?usp=share_link):

```shell
python demo/multi_modality_demo.py demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__LIDAR_TOP__1532402927647951.pcd.bin demo/data/nuscenes/ demo/data/nuscenes/n015-2018-07-24-11-22-45+0800.pkl projects/BEVFusion/configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py ${CHECKPOINT_FILE} --cam-type all --score-thr 0.2 --show
python demo/multi_modality_demo.py demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__LIDAR_TOP__1532402927647951.pcd.bin demo/data/nuscenes/ demo/data/nuscenes/n015-2018-07-24-11-22-45+0800.pkl projects/BEVFusion/configs/bevfusion_lidar-cam_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py ${CHECKPOINT_FILE} --cam-type all --score-thr 0.2 --show
```

### Training commands

In MMDetection3D's root directory, run the following command to train the model:
1. You should train the lidar-only detector first:

```bash
python tools/train.py projects/BEVFusion/configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py
python tools/train.py projects/BEVFusion/configs/bevfusion_lidar_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py
```

For multi-gpu training, run:
2. Download the [Swin pre-trained model](<>). With the pre-trained image backbone and the lidar-only detector from step 1, you can train the lidar-camera fusion model:

```bash
python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=${NUM_GPUS} --master_port=29506 --master_addr="127.0.0.1" tools/train.py projects/BEVFusion/configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py
python tools/train.py projects/BEVFusion/configs/bevfusion_lidar-cam_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py --cfg-options load_from=${LIDAR_PRETRAINED_CHECKPOINT} model.img_backbone.init_cfg.checkpoint=${IMAGE_PRETRAINED_BACKBONE}
```
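
For multi-GPU training, the standard MMDetection3D launcher should work as well; a sketch assuming the usual `tools/dist_train.sh` entry point:

```bash
# Sketch: train the fusion stage on 8 GPUs with the stock launcher.
bash tools/dist_train.sh projects/BEVFusion/configs/bevfusion_lidar-cam_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py 8 \
    --cfg-options load_from=${LIDAR_PRETRAINED_CHECKPOINT} model.img_backbone.init_cfg.checkpoint=${IMAGE_PRETRAINED_BACKBONE}
```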

### Testing commands

In MMDetection3D's root directory, run the following command to test the model:

```bash
python tools/test.py projects/BEVFusion/configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py ${CHECKPOINT_PATH}
python tools/test.py projects/BEVFusion/configs/bevfusion_lidar-cam_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py ${CHECKPOINT_PATH}
```

## Results and models
@@ -103,9 +103,9 @@ A project does not necessarily have to be finished in a single PR, but it's esse

<!-- As this template does. -->

- [ ] Milestone 2: Indicates a successful model implementation.
- [x] Milestone 2: Indicates a successful model implementation.

- [ ] Training-time correctness
- [x] Training-time correctness

<!-- If you are reproducing the result from a paper, checking this item means that you should have trained your model from scratch based on the original paper's specification and verified that the final result matches the report within a minor error range. -->

10 changes: 6 additions & 4 deletions projects/BEVFusion/bevfusion/__init__.py
@@ -1,18 +1,20 @@
from .bevfusion import BEVFusion
from .bevfusion_necks import GeneralizedLSSFPN
from .depth_lss import DepthLSSTransform
from .depth_lss import DepthLSSTransform, LSSTransform
from .loading import BEVLoadMultiViewImageFromFiles
from .sparse_encoder import BEVFusionSparseEncoder
from .transformer import TransformerDecoderLayer
from .transforms_3d import GridMask, ImageAug3D
from .transforms_3d import (BEVFusionGlobalRotScaleTrans,
BEVFusionRandomFlip3D, GridMask, ImageAug3D)
from .transfusion_head import ConvFuser, TransFusionHead
from .utils import (BBoxBEVL1Cost, HeuristicAssigner3D, HungarianAssigner3D,
IoU3DCost)

__all__ = [
'BEVFusion', 'TransFusionHead', 'ConvFuser', 'ImageAug3D', 'GridMask',
'GeneralizedLSSFPN', 'HungarianAssigner3D', 'BBoxBEVL1Cost', 'IoU3DCost',
'HeuristicAssigner3D', 'DepthLSSTransform',
'HeuristicAssigner3D', 'DepthLSSTransform', 'LSSTransform',
'BEVLoadMultiViewImageFromFiles', 'BEVFusionSparseEncoder',
'TransformerDecoderLayer'
'TransformerDecoderLayer', 'BEVFusionRandomFlip3D',
'BEVFusionGlobalRotScaleTrans'
]
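
The two newly exported transforms are fusion-aware variants of the standard LiDAR augmentations. A hypothetical pipeline excerpt showing where they would sit (parameter values are placeholders, not the project's actual config):

```python
# Illustrative only: the ranges below are placeholders, not BEVFusion's
# actual config values.
train_pipeline = [
    dict(
        type='BEVFusionGlobalRotScaleTrans',
        rot_range=[-0.78539816, 0.78539816],  # roughly +/- pi/4
        scale_ratio_range=[0.9, 1.1],
        translation_std=0.5),
    dict(type='BEVFusionRandomFlip3D'),
]
```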
154 changes: 105 additions & 49 deletions projects/BEVFusion/bevfusion/bevfusion.py
@@ -1,7 +1,11 @@
from typing import Dict, List, Optional
from collections import OrderedDict
from copy import deepcopy
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.distributed as dist
from mmengine.utils import is_list_of
from torch import Tensor
from torch.nn import functional as F

@@ -23,7 +27,7 @@ def __init__(
fusion_layer: Optional[dict] = None,
img_backbone: Optional[dict] = None,
pts_backbone: Optional[dict] = None,
vtransform: Optional[dict] = None,
view_transform: Optional[dict] = None,
img_neck: Optional[dict] = None,
pts_neck: Optional[dict] = None,
bbox_head: Optional[dict] = None,
@@ -40,20 +44,21 @@

self.pts_voxel_encoder = MODELS.build(pts_voxel_encoder)

self.img_backbone = MODELS.build(img_backbone)
self.img_neck = MODELS.build(img_neck)
self.vtransform = MODELS.build(vtransform)
self.img_backbone = MODELS.build(
img_backbone) if img_backbone is not None else None
self.img_neck = MODELS.build(
img_neck) if img_neck is not None else None
self.view_transform = MODELS.build(
view_transform) if view_transform is not None else None
self.pts_middle_encoder = MODELS.build(pts_middle_encoder)

self.fusion_layer = MODELS.build(fusion_layer)
self.fusion_layer = MODELS.build(
fusion_layer) if fusion_layer is not None else None

self.pts_backbone = MODELS.build(pts_backbone)
self.pts_neck = MODELS.build(pts_neck)

self.bbox_head = MODELS.build(bbox_head)
# hard code here where using converted checkpoint of original
# implementation of `BEVFusion`
self.use_converted_checkpoint = True

self.init_weights()

@@ -67,6 +72,46 @@ def _forward(self,
"""
pass

def parse_losses(
self, losses: Dict[str, torch.Tensor]
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""Parses the raw outputs (losses) of the network.

Args:
losses (dict): Raw output of the network, which usually contains
losses and other necessary information.

Returns:
tuple[Tensor, dict]: There are two elements. The first is the
loss tensor passed to optim_wrapper which may be a weighted sum
of all losses, and the second is log_vars which will be sent to
the logger.
"""
log_vars = []
for loss_name, loss_value in losses.items():
if isinstance(loss_value, torch.Tensor):
log_vars.append([loss_name, loss_value.mean()])
elif is_list_of(loss_value, torch.Tensor):
log_vars.append(
[loss_name,
sum(_loss.mean() for _loss in loss_value)])
else:
raise TypeError(
f'{loss_name} is not a tensor or list of tensors')

loss = sum(value for key, value in log_vars if 'loss' in key)
log_vars.insert(0, ['loss', loss])
log_vars = OrderedDict(log_vars) # type: ignore

for loss_name, loss_value in log_vars.items():
# reduce loss when distributed training
if dist.is_available() and dist.is_initialized():
loss_value = loss_value.data.clone()
dist.all_reduce(loss_value.div_(dist.get_world_size()))
log_vars[loss_name] = loss_value.item()

return loss, log_vars # type: ignore

def init_weights(self) -> None:
if self.img_backbone is not None:
self.img_backbone.init_weights()
@@ -94,7 +139,7 @@ def extract_img_feat(
img_metas,
) -> torch.Tensor:
B, N, C, H, W = x.size()
x = x.view(B * N, C, H, W)
x = x.view(B * N, C, H, W).contiguous()

x = self.img_backbone(x)
x = self.img_neck(x)
@@ -105,22 +150,25 @@
BN, C, H, W = x.size()
x = x.view(B, int(BN / B), C, H, W)

x = self.vtransform(
x,
points,
lidar2image,
camera_intrinsics,
camera2lidar,
img_aug_matrix,
lidar_aug_matrix,
img_metas,
)
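# The view transform below runs under autocast pinned to float32: its
# geometric ops are precision-sensitive when the rest of the model uses AMP.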
with torch.autocast(device_type='cuda', dtype=torch.float32):
x = self.view_transform(
x,
points,
lidar2image,
camera_intrinsics,
camera2lidar,
img_aug_matrix,
lidar_aug_matrix,
img_metas,
)
return x

def extract_pts_feat(self, batch_inputs_dict) -> torch.Tensor:
points = batch_inputs_dict['points']
feats, coords, sizes = self.voxelize(points)
batch_size = coords[-1, 0] + 1
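# Voxelization below runs with autocast disabled and points cast to float32,
# so voxel features stay in full precision under AMP.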
with torch.autocast('cuda', enabled=False):
points = [point.float() for point in points]
feats, coords, sizes = self.voxelize(points)
batch_size = coords[-1, 0] + 1
x = self.pts_middle_encoder(feats, coords, batch_size)
return x

@@ -184,11 +232,6 @@ def predict(self, batch_inputs_dict: Dict[str, Optional[Tensor]],

if self.with_bbox_head:
outputs = self.bbox_head.predict(feats, batch_input_metas)
if self.use_converted_checkpoint:
outputs[0]['bboxes_3d'].tensor[:, 6] = -outputs[0][
'bboxes_3d'].tensor[:, 6] - np.pi / 2
outputs[0]['bboxes_3d'].tensor[:, 3:5] = outputs[0][
'bboxes_3d'].tensor[:, [4, 3]]

res = self.add_pred_to_datasample(batch_data_samples, outputs)

@@ -202,28 +245,32 @@
):
imgs = batch_inputs_dict.get('imgs', None)
points = batch_inputs_dict.get('points', None)

lidar2image, camera_intrinsics, camera2lidar = [], [], []
img_aug_matrix, lidar_aug_matrix = [], []
for i, meta in enumerate(batch_input_metas):
lidar2image.append(meta['lidar2img'])
camera_intrinsics.append(meta['cam2img'])
camera2lidar.append(meta['cam2lidar'])
img_aug_matrix.append(meta.get('img_aug_matrix', np.eye(4)))
lidar_aug_matrix.append(meta.get('lidar_aug_matrix', np.eye(4)))

lidar2image = imgs.new_tensor(np.asarray(lidar2image))
camera_intrinsics = imgs.new_tensor(np.array(camera_intrinsics))
camera2lidar = imgs.new_tensor(np.asarray(camera2lidar))
img_aug_matrix = imgs.new_tensor(np.asarray(img_aug_matrix))
lidar_aug_matrix = imgs.new_tensor(np.asarray(lidar_aug_matrix))
img_feature = self.extract_img_feat(imgs, points, lidar2image,
camera_intrinsics, camera2lidar,
img_aug_matrix, lidar_aug_matrix,
batch_input_metas)
features = []
if imgs is not None:
imgs = imgs.contiguous()
lidar2image, camera_intrinsics, camera2lidar = [], [], []
img_aug_matrix, lidar_aug_matrix = [], []
for i, meta in enumerate(batch_input_metas):
lidar2image.append(meta['lidar2img'])
camera_intrinsics.append(meta['cam2img'])
camera2lidar.append(meta['cam2lidar'])
img_aug_matrix.append(meta.get('img_aug_matrix', np.eye(4)))
lidar_aug_matrix.append(
meta.get('lidar_aug_matrix', np.eye(4)))

lidar2image = imgs.new_tensor(np.asarray(lidar2image))
camera_intrinsics = imgs.new_tensor(np.array(camera_intrinsics))
camera2lidar = imgs.new_tensor(np.asarray(camera2lidar))
img_aug_matrix = imgs.new_tensor(np.asarray(img_aug_matrix))
lidar_aug_matrix = imgs.new_tensor(np.asarray(lidar_aug_matrix))
img_feature = self.extract_img_feat(imgs, deepcopy(points),
lidar2image, camera_intrinsics,
camera2lidar, img_aug_matrix,
lidar_aug_matrix,
batch_input_metas)
features.append(img_feature)
pts_feature = self.extract_pts_feat(batch_inputs_dict)

features = [img_feature, pts_feature]
features.append(pts_feature)

if self.fusion_layer is not None:
x = self.fusion_layer(features)
@@ -239,4 +286,13 @@
def loss(self, batch_inputs_dict: Dict[str, Optional[Tensor]],
batch_data_samples: List[Det3DDataSample],
**kwargs) -> List[Det3DDataSample]:
pass
batch_input_metas = [item.metainfo for item in batch_data_samples]
feats = self.extract_feat(batch_inputs_dict, batch_input_metas)

losses = dict()
if self.with_bbox_head:
bbox_loss = self.bbox_head.loss(feats, batch_data_samples)

losses.update(bbox_loss)

return losses
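
To make the reduction in `parse_losses` concrete, a standalone sketch (not part of the PR) reproducing its logic on a toy loss dict:

```python
import torch

# Toy loss dict mimicking what the bbox head might return.
losses = {
    'loss_heatmap': torch.tensor([0.7, 0.9]),             # mean -> 0.8
    'loss_bbox': [torch.tensor(0.3), torch.tensor(0.1)],  # sum of means -> 0.4
    'matched_ious': torch.tensor(0.55),                   # logged, not summed
}

log_vars = []
for name, value in losses.items():
    if isinstance(value, torch.Tensor):
        log_vars.append([name, value.mean()])
    else:  # list of tensors
        log_vars.append([name, sum(v.mean() for v in value)])

# Only keys containing 'loss' contribute to the total fed to the optimizer;
# everything in log_vars is logged.
total = sum(v for k, v in log_vars if 'loss' in k)
print(total)  # tensor(1.2000)
```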