diff --git a/demo/data/nuscenes/n015-2018-07-24-11-22-45+0800.pkl b/demo/data/nuscenes/n015-2018-07-24-11-22-45+0800.pkl
index 95b51083a5..60f56360cd 100644
Binary files a/demo/data/nuscenes/n015-2018-07-24-11-22-45+0800.pkl and b/demo/data/nuscenes/n015-2018-07-24-11-22-45+0800.pkl differ
diff --git a/demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__CAM_BACK_LEFT__1532402927647423.jpg b/demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__CAM_BACK_LEFT__1532402927647423.jpg
new file mode 100644
index 0000000000..db935ab150
Binary files /dev/null and b/demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__CAM_BACK_LEFT__1532402927647423.jpg differ
diff --git a/demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__CAM_BACK_RIGHT__1532402927627893.jpg b/demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__CAM_BACK_RIGHT__1532402927627893.jpg
new file mode 100644
index 0000000000..9542920614
Binary files /dev/null and b/demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__CAM_BACK_RIGHT__1532402927627893.jpg differ
diff --git a/demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__CAM_FRONT_LEFT__1532402927604844.jpg b/demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__CAM_FRONT_LEFT__1532402927604844.jpg
new file mode 100644
index 0000000000..2197221b4b
Binary files /dev/null and b/demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__CAM_FRONT_LEFT__1532402927604844.jpg differ
diff --git a/demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__CAM_FRONT_RIGHT__1532402927620339.jpg b/demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__CAM_FRONT_RIGHT__1532402927620339.jpg
new file mode 100644
index 0000000000..6ab0a63161
Binary files /dev/null and b/demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__CAM_FRONT_RIGHT__1532402927620339.jpg differ
diff --git a/demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__CAM_FRONT__1532402927612460.jpg b/demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__CAM_FRONT__1532402927612460.jpg
new file mode 100644
index 0000000000..5123fabb2f
Binary files /dev/null and b/demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__CAM_FRONT__1532402927612460.jpg differ
diff --git a/demo/multi_modality_demo.py b/demo/multi_modality_demo.py
index c5486bf9c9..7d7680197e 100644
--- a/demo/multi_modality_demo.py
+++ b/demo/multi_modality_demo.py
@@ -49,8 +49,15 @@ def main(args):
     result, data = inference_multi_modality_detector(model, args.pcd, args.img, args.ann, args.cam_type)
     points = data['inputs']['points']
-    img = mmcv.imread(args.img)
-    img = mmcv.imconvert(img, 'bgr', 'rgb')
+    if isinstance(result.img_path, list):
+        img = []
+        for img_path in result.img_path:
+            single_img = mmcv.imread(img_path)
+            single_img = mmcv.imconvert(single_img, 'bgr', 'rgb')
+            img.append(single_img)
+    else:
+        img = mmcv.imread(result.img_path)
+        img = mmcv.imconvert(img, 'bgr', 'rgb')
     data_input = dict(points=points, img=img)
     # show the results
diff --git a/docs/en/user_guides/inference.md b/docs/en/user_guides/inference.md
index e52c0b0f66..68570ef84d 100644
--- a/docs/en/user_guides/inference.md
+++ b/docs/en/user_guides/inference.md
@@ -78,6 +78,12 @@ Example on SUN RGB-D data using [ImVoteNet model](https://download.openmmlab.com
 python demo/multi_modality_demo.py demo/data/sunrgbd/000017.bin demo/data/sunrgbd/000017.jpg demo/data/sunrgbd/sunrgbd_000017_infos.pkl configs/imvotenet/imvotenet_stage2_8xb16_sunrgbd-3d.py ${CHECKPOINT_FILE} --cam-type CAM0 --show --score-thr 0.6
 ```
 
+Example on NuScenes data using [BEVFusion model](https://drive.google.com/file/d/1QkvbYDk4G2d6SZoeJqish13qSyXA4lp3/view?usp=share_link):
+
+```shell
+python demo/multi_modality_demo.py demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__LIDAR_TOP__1532402927647951.pcd.bin demo/data/nuscenes/ demo/data/nuscenes/n015-2018-07-24-11-22-45+0800.pkl projects/BEVFusion/configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py ${CHECKPOINT_FILE} --cam-type all --score-thr 0.2 --show +``` + ### 3D Segmentation To test a 3D segmentor on point cloud data, simply run: diff --git a/docs/en/user_guides/visualization.md b/docs/en/user_guides/visualization.md index 9aa0c69fb3..3d51e77351 100644 --- a/docs/en/user_guides/visualization.md +++ b/docs/en/user_guides/visualization.md @@ -42,18 +42,19 @@ We support drawing 3D boxes on point cloud by using `draw_bboxes_3d`. ```python import torch +import numpy as np from mmdet3d.visualization import Det3DLocalVisualizer from mmdet3d.structures import LiDARInstance3DBoxes -points = np.fromfile('tests/data/kitti/training/velodyne/000000.bin', dtype=np.float32) +points = np.fromfile('demo/data/kitti/000008.bin', dtype=np.float32) points = points.reshape(-1, 4) visualizer = Det3DLocalVisualizer() # set point cloud in visualizer visualizer.set_points(points) -bboxes_3d = LiDARInstance3DBoxes(torch.tensor( - [[8.7314, -1.8559, -1.5997, 1.2000, 0.4800, 1.8900, - -1.5808]])), +bboxes_3d = LiDARInstance3DBoxes( + torch.tensor([[8.7314, -1.8559, -1.5997, 4.2000, 3.4800, 1.8900, + -1.5808]])) # Draw 3D bboxes visualizer.draw_bboxes_3d(bboxes_3d) visualizer.show() @@ -92,8 +93,6 @@ visualizer.draw_proj_bboxes_3d(gt_bboxes_3d, input_meta) visualizer.show() ``` -![mono3d](../../../resources/mono3d.png) - ### Drawing BEV Boxes We support drawing BEV boxes by using `draw_bev_bboxes`. @@ -120,23 +119,22 @@ visualizer.draw_bev_bboxes(gt_bboxes_3d, edge_colors='orange') visualizer.show() ``` - - ### Drawing 3D Semantic Mask We support draw segmentation mask via per-point colorization by using `draw_seg_mask`. ```python -import torch +import numpy as np from mmdet3d.visualization import Det3DLocalVisualizer -points = np.fromfile('tests/data/s3dis/points/Area_1_office_2.bin', dtype=np.float32) +points = np.fromfile('demo/data/sunrgbd/000017.bin', dtype=np.float32) points = points.reshape(-1, 3) visualizer = Det3DLocalVisualizer() mask = np.random.rand(points.shape[0], 3) points_with_mask = np.concatenate((points, mask), axis=-1) # Draw 3D points with mask +visualizer.set_points(points, pcd_mode=2, vis_mode='add') visualizer.draw_seg_mask(points_with_mask) visualizer.show() ``` @@ -168,10 +166,10 @@ This allows the inference and results generation to be done in remote server and We also provide scripts to visualize the dataset without inference. You can use `tools/misc/browse_dataset.py` to show loaded data and ground-truth online and save them on the disk. Currently we support single-modality 3D detection and 3D segmentation on all the datasets, multi-modality 3D detection on KITTI and SUN RGB-D, as well as monocular 3D detection on nuScenes. To browse the KITTI dataset, you can run the following command: ```shell -python tools/misc/browse_dataset.py configs/_base_/datasets/kitti-3d-3class.py --task det --output-dir ${OUTPUT_DIR} +python tools/misc/browse_dataset.py configs/_base_/datasets/kitti-3d-3class.py --task lidar_det --output-dir ${OUTPUT_DIR} ``` -**Notice**: Once specifying `--output-dir`, the images of views specified by users will be saved when pressing `_ESC_` in open3d window. 
+**Notice**: Once specifying `--output-dir`, the images of views specified by users will be saved when pressing `_ESC_` in open3d window. If you want to zoom out/in the point clouds to inspect more details, you could specify `--show-interval=0` in the command. To verify the data consistency and the effect of data augmentation, you can also add `--aug` flag to visualize the data after data augmentation using the command as below: @@ -182,7 +180,7 @@ python tools/misc/browse_dataset.py configs/_base_/datasets/kitti-3d-3class.py - If you also want to show 2D images with 3D bounding boxes projected onto them, you need to find a config that supports multi-modality data loading, and then change the `--task` args to `multi-modality_det`. An example is showed below: ```shell -python tools/misc/browse_dataset.py configs/mvxnet/dv_mvx-fpn_second_secfpn_adamw_2x8_80e_kitti-3d-3class.py --task multi-modality_det --output-dir ${OUTPUT_DIR} +python tools/misc/browse_dataset.py configs/mvxnet/mvxnet_fpn_dv_second_secfpn_8xb2-80e_kitti-3d-3class.py --task multi-modality_det --output-dir ${OUTPUT_DIR} ``` ![](../../../resources/browse_dataset_multi_modality.png) @@ -190,7 +188,7 @@ python tools/misc/browse_dataset.py configs/mvxnet/dv_mvx-fpn_second_secfpn_adam You can simply browse different datasets using different configs, e.g. visualizing the ScanNet dataset in 3D semantic segmentation task: ```shell -python tools/misc/browse_dataset.py configs/_base_/datasets/scannet-seg.py --task lidar_seg --output-dir ${OUTPUT_DIR} --online +python tools/misc/browse_dataset.py configs/_base_/datasets/scannet-seg.py --task lidar_seg --output-dir ${OUTPUT_DIR} ``` ![](../../../resources/browse_dataset_seg.png) @@ -198,7 +196,7 @@ python tools/misc/browse_dataset.py configs/_base_/datasets/scannet-seg.py --tas And browsing the nuScenes dataset in monocular 3D detection task: ```shell -python tools/misc/browse_dataset.py configs/_base_/datasets/nus-mono3d.py --task mono_det --output-dir ${OUTPUT_DIR} --online +python tools/misc/browse_dataset.py configs/_base_/datasets/nus-mono3d.py --task mono_det --output-dir ${OUTPUT_DIR} ``` ![](../../../resources/browse_dataset_mono.png) diff --git a/docs/zh_cn/user_guides/inference.md b/docs/zh_cn/user_guides/inference.md index ceb2230791..a3a62b53f4 100644 --- a/docs/zh_cn/user_guides/inference.md +++ b/docs/zh_cn/user_guides/inference.md @@ -78,6 +78,12 @@ python demo/multi_modality_demo.py demo/data/kitti/000008.bin demo/data/kitti/00 python demo/multi_modality_demo.py demo/data/sunrgbd/000017.bin demo/data/sunrgbd/000017.jpg demo/data/sunrgbd/sunrgbd_000017_infos.pkl configs/imvotenet/imvotenet_stage2_8xb16_sunrgbd-3d.py ${CHECKPOINT_FILE} --cam-type CAM0 --show --score-thr 0.6 ``` +在 NuScenes 数据上测试 [BEVFusion 模型](https://drive.google.com/file/d/1QkvbYDk4G2d6SZoeJqish13qSyXA4lp3/view?usp=share_link) + +```shell +python demo/multi_modality_demo.py demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__LIDAR_TOP__1532402927647951.pcd.bin demo/data/nuscenes/ demo/data/nuscenes/n015-2018-07-24-11-22-45+0800.pkl projects/BEVFusion/configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py ${CHECKPOINT_FILE} --cam-type all --score-thr 0.2 --show +``` + ### 3D 分割 在点云数据上测试 3D 分割器,运行: diff --git a/docs/zh_cn/user_guides/visualization.md b/docs/zh_cn/user_guides/visualization.md index 6623d7ed28..d09c45714b 100644 --- a/docs/zh_cn/user_guides/visualization.md +++ b/docs/zh_cn/user_guides/visualization.md @@ -42,18 +42,19 @@ visualizer.show() ```python import torch +import numpy as 
np from mmdet3d.visualization import Det3DLocalVisualizer from mmdet3d.structures import LiDARInstance3DBoxes -points = np.fromfile('tests/data/kitti/training/velodyne/000000.bin', dtype=np.float32) +points = np.fromfile('demo/data/kitti/000008.bin', dtype=np.float32) points = points.reshape(-1, 4) visualizer = Det3DLocalVisualizer() # set point cloud in visualizer visualizer.set_points(points) -bboxes_3d = LiDARInstance3DBoxes(torch.tensor( - [[8.7314, -1.8559, -1.5997, 1.2000, 0.4800, 1.8900, - -1.5808]])), +bboxes_3d = LiDARInstance3DBoxes( + torch.tensor([[8.7314, -1.8559, -1.5997, 4.2000, 3.4800, 1.8900, + -1.5808]])) # Draw 3D bboxes visualizer.draw_bboxes_3d(bboxes_3d) visualizer.show() @@ -92,8 +93,6 @@ visualizer.draw_proj_bboxes_3d(gt_bboxes_3d, input_meta) visualizer.show() ``` -![mono3d](../../../resources/mono3d.png) - ### 绘制 BEV 视角的框 通过使用 `draw_bev_bboxes`,我们支持绘制 BEV 视角下的框。 @@ -120,23 +119,22 @@ visualizer.draw_bev_bboxes(gt_bboxes_3d, edge_colors='orange') visualizer.show() ``` - - ### 绘制 3D 分割掩码 通过使用 `draw_seg_mask`,我们支持通过逐点着色来绘制分割掩码。 ```python -import torch +import numpy as np from mmdet3d.visualization import Det3DLocalVisualizer -points = np.fromfile('tests/data/s3dis/points/Area_1_office_2.bin', dtype=np.float32) +points = np.fromfile('demo/data/sunrgbd/000017.bin', dtype=np.float32) points = points.reshape(-1, 3) visualizer = Det3DLocalVisualizer() mask = np.random.rand(points.shape[0], 3) points_with_mask = np.concatenate((points, mask), axis=-1) # Draw 3D points with mask +visualizer.set_points(points, pcd_mode=2, vis_mode='add') visualizer.draw_seg_mask(points_with_mask) visualizer.show() ``` @@ -171,7 +169,7 @@ python tools/misc/visualize_results.py ${CONFIG_FILE} --result ${RESULTS_PATH} - python tools/misc/browse_dataset.py configs/_base_/datasets/kitti-3d-3class.py --task lidar_det --output-dir ${OUTPUT_DIR} ``` -**注意**:一旦指定了 `--output-dir`,当在 open3d 窗口中按下 `_ESC_` 时,用户指定的视图图像将会被保存下来。 +**注意**:一旦指定了 `--output-dir`,当在 open3d 窗口中按下 `_ESC_` 时,用户指定的视图图像将会被保存下来。如果你想要对点云进行缩放操作以观察更多细节, 你可以在命令中指定 `--show-interval=0`。 为了验证数据的一致性和数据增强的效果,你可以加上 `--aug` 来可视化数据增强后的数据,指令如下所示: @@ -182,7 +180,7 @@ python tools/misc/browse_dataset.py configs/_base_/datasets/kitti-3d-3class.py - 如果你想显示带有投影的 3D 边界框的 2D 图像,你需要一个支持多模态数据加载的配置文件,并将 `--task` 参数改为 `multi-modality_det`。示例如下: ```shell -python tools/misc/browse_dataset.py configs/mvxnet/dv_mvx-fpn_second_secfpn_adamw_2x8_80e_kitti-3d-3class.py --task multi-modality_det --output-dir ${OUTPUT_DIR} +python tools/misc/browse_dataset.py configs/mvxnet/mvxnet_fpn_dv_second_secfpn_8xb2-80e_kitti-3d-3class.py --task multi-modality_det --output-dir ${OUTPUT_DIR} ``` ![](../../../resources/browse_dataset_multi_modality.png) @@ -190,7 +188,7 @@ python tools/misc/browse_dataset.py configs/mvxnet/dv_mvx-fpn_second_secfpn_adam 你可以使用不同的配置浏览不同的数据集,例如在 3D 语义分割任务中可视化 ScanNet 数据集: ```shell -python tools/misc/browse_dataset.py configs/_base_/datasets/scannet-seg.py --task lidar_seg --output-dir ${OUTPUT_DIR} --online +python tools/misc/browse_dataset.py configs/_base_/datasets/scannet-seg.py --task lidar_seg --output-dir ${OUTPUT_DIR} ``` ![](../../../resources/browse_dataset_seg.png) @@ -198,7 +196,7 @@ python tools/misc/browse_dataset.py configs/_base_/datasets/scannet-seg.py --tas 在单目 3D 检测任务中浏览 nuScenes 数据集: ```shell -python tools/misc/browse_dataset.py configs/_base_/datasets/nus-mono3d.py --task mono_det --output-dir ${OUTPUT_DIR} --online +python tools/misc/browse_dataset.py configs/_base_/datasets/nus-mono3d.py --task mono_det --output-dir ${OUTPUT_DIR} 
``` ![](../../../resources/browse_dataset_mono.png) diff --git a/mmdet3d/apis/inference.py b/mmdet3d/apis/inference.py index 2ba6eed23c..c4d2f2efc1 100644 --- a/mmdet3d/apis/inference.py +++ b/mmdet3d/apis/inference.py @@ -188,10 +188,10 @@ def inference_multi_modality_detector(model: nn.Module, imgs (str, Sequence[str]): Either image files or loaded images. ann_file (str, Sequence[str]): Annotation files. - cam_type (str): Image of Camera chose to infer. - For kitti dataset, it should be 'CAM2', - and for nuscenes dataset, it should be - 'CAM_FRONT'. Defaults to 'CAM_FRONT'. + cam_type (str): Image of Camera chose to infer. When detector only uses + single-view image, we need to specify a camera view. For kitti + dataset, it should be 'CAM2'. For sunrgbd, it should be 'CAM0'. + When detector uses multi-view images, we should set it to 'all'. Returns: :obj:`Det3DDataSample` or list[:obj:`Det3DDataSample`]: @@ -220,37 +220,51 @@ def inference_multi_modality_detector(model: nn.Module, data = [] for index, pcd in enumerate(pcds): # get data info containing calib - img = imgs[index] data_info = data_list[index] - img_path = data_info['images'][cam_type]['img_path'] - - if osp.basename(img_path) != osp.basename(img): - raise ValueError(f'the info file of {img_path} is not provided.') + img = imgs[index] - # TODO: check the name consistency of - # image file and point cloud file - # TODO: support multi-view image loading - data_ = dict( - lidar_points=dict(lidar_path=pcd), - img_path=img, - box_type_3d=box_type_3d, - box_mode_3d=box_mode_3d) + if cam_type != 'all': + assert osp.isfile(img), f'{img} must be a file.' + img_path = data_info['images'][cam_type]['img_path'] + if osp.basename(img_path) != osp.basename(img): + raise ValueError( + f'the info file of {img_path} is not provided.') + data_ = dict( + lidar_points=dict(lidar_path=pcd), + img_path=img, + box_type_3d=box_type_3d, + box_mode_3d=box_mode_3d) + data_info['images'][cam_type]['img_path'] = img + if 'cam2img' in data_info['images'][cam_type]: + # The data annotation in SRUNRGBD dataset does not contain + # `cam2img` + data_['cam2img'] = np.array( + data_info['images'][cam_type]['cam2img']) + + # LiDAR to image conversion for KITTI dataset + if box_mode_3d == Box3DMode.LIDAR: + if 'lidar2img' in data_info['images'][cam_type]: + data_['lidar2img'] = np.array( + data_info['images'][cam_type]['lidar2img']) + # Depth to image conversion for SUNRGBD dataset + elif box_mode_3d == Box3DMode.DEPTH: + data_['depth2img'] = np.array( + data_info['images'][cam_type]['depth2img']) + else: + assert osp.isdir(img), f'{img} must be a file directory' + for _, img_info in data_info['images'].items(): + img_info['img_path'] = osp.join(img, img_info['img_path']) + assert osp.isfile(img_info['img_path'] + ), f'{img_info["img_path"]} does not exist.' 
+ data_ = dict( + lidar_points=dict(lidar_path=pcd), + images=data_info['images'], + box_type_3d=box_type_3d, + box_mode_3d=box_mode_3d) - data_info['images'][cam_type]['img_path'] = img - if 'cam2img' in data_info['images'][cam_type]: - # The data annotation in SRUNRGBD dataset does not contain - # `cam2img` - data_['cam2img'] = np.array( - data_info['images'][cam_type]['cam2img']) - - # LiDAR to image conversion for KITTI dataset - if box_mode_3d == Box3DMode.LIDAR: - data_['lidar2img'] = np.array( - data_info['images'][cam_type]['lidar2img']) - # Depth to image conversion for SUNRGBD dataset - elif box_mode_3d == Box3DMode.DEPTH: - data_['depth2img'] = np.array( - data_info['images'][cam_type]['depth2img']) + if 'timestamp' in data_info: + # Using multi-sweeps need `timestamp` + data_['timestamp'] = data_info['timestamp'] data_ = test_pipeline(data_) data.append(data_) diff --git a/mmdet3d/engine/hooks/visualization_hook.py b/mmdet3d/engine/hooks/visualization_hook.py index 5fc46ba6e6..6c2d18d4c5 100644 --- a/mmdet3d/engine/hooks/visualization_hook.py +++ b/mmdet3d/engine/hooks/visualization_hook.py @@ -102,8 +102,17 @@ def after_val_iter(self, runner: Runner, batch_idx: int, data_batch: dict, ]: assert 'img_path' in outputs[0], 'img_path is not in outputs[0]' img_path = outputs[0].img_path - img_bytes = get(img_path, backend_args=self.backend_args) - img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + if isinstance(img_path, list): + img = [] + for single_img_path in img_path: + img_bytes = get( + single_img_path, backend_args=self.backend_args) + single_img = mmcv.imfrombytes( + img_bytes, channel_order='rgb') + img.append(single_img) + else: + img_bytes = get(img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') data_input['img'] = img if self.vis_task in ['lidar_det', 'multi-modality_det', 'lidar_seg']: @@ -161,10 +170,21 @@ def after_test_iter(self, runner: Runner, batch_idx: int, data_batch: dict, assert 'img_path' in data_sample, \ 'img_path is not in data_sample' img_path = data_sample.img_path - img_bytes = get(img_path, backend_args=self.backend_args) - img = mmcv.imfrombytes(img_bytes, channel_order='rgb') + if isinstance(img_path, list): + img = [] + for single_img_path in img_path: + img_bytes = get( + single_img_path, backend_args=self.backend_args) + single_img = mmcv.imfrombytes( + img_bytes, channel_order='rgb') + img.append(single_img) + else: + img_bytes = get(img_path, backend_args=self.backend_args) + img = mmcv.imfrombytes(img_bytes, channel_order='rgb') data_input['img'] = img if self.test_out_dir is not None: + if isinstance(img_path, list): + img_path = img_path[0] out_file = osp.basename(img_path) out_file = osp.join(self.test_out_dir, out_file) diff --git a/mmdet3d/visualization/local_visualizer.py b/mmdet3d/visualization/local_visualizer.py index 472bc98ecd..231edb76ba 100644 --- a/mmdet3d/visualization/local_visualizer.py +++ b/mmdet3d/visualization/local_visualizer.py @@ -1,6 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy -from typing import List, Optional, Tuple, Union +import math +import time +from typing import List, Optional, Sequence, Tuple, Union import matplotlib.pyplot as plt import mmcv @@ -67,6 +69,8 @@ class Det3DLocalVisualizer(DetLocalVisualizer): Defaults to dict(size=1, origin=[0, 0, 0]). alpha (int or float): The transparency of bboxes or mask. Defaults to 0.8. + multi_imgs_col (int): The number of columns in arrangement when showing + multi-view images. 
Examples: >>> import numpy as np @@ -102,19 +106,23 @@ class Det3DLocalVisualizer(DetLocalVisualizer): ... vis_task='lidar_seg') """ - def __init__(self, - name: str = 'visualizer', - points: Optional[np.ndarray] = None, - image: Optional[np.ndarray] = None, - pcd_mode: int = 0, - vis_backends: Optional[List[dict]] = None, - save_dir: Optional[str] = None, - bbox_color: Optional[Union[str, Tuple[int]]] = None, - text_color: Union[str, Tuple[int]] = (200, 200, 200), - mask_color: Optional[Union[str, Tuple[int]]] = None, - line_width: Union[int, float] = 3, - frame_cfg: dict = dict(size=1, origin=[0, 0, 0]), - alpha: Union[int, float] = 0.8) -> None: + def __init__( + self, + name: str = 'visualizer', + points: Optional[np.ndarray] = None, + image: Optional[np.ndarray] = None, + pcd_mode: int = 0, + vis_backends: Optional[List[dict]] = None, + save_dir: Optional[str] = None, + bbox_color: Optional[Union[str, Tuple[int]]] = None, + text_color: Union[str, Tuple[int]] = (200, 200, 200), + mask_color: Optional[Union[str, Tuple[int]]] = None, + line_width: Union[int, float] = 3, + frame_cfg: dict = dict(size=1, origin=[0, 0, 0]), + alpha: Union[int, float] = 0.8, + multi_imgs_col: int = 3, + fig_show_cfg: dict = dict(figsize=(18, 12)) + ) -> None: super().__init__( name=name, image=image, @@ -128,6 +136,8 @@ def __init__(self, if points is not None: self.set_points(points, pcd_mode=pcd_mode, frame_cfg=frame_cfg) self.pts_seg_num = 0 + self.multi_imgs_col = multi_imgs_col + self.fig_show_cfg.update(fig_show_cfg) def _clear_o3d_vis(self) -> None: """Clear open3d vis.""" @@ -163,7 +173,7 @@ def set_points(self, pcd_mode: int = 0, vis_mode: str = 'replace', frame_cfg: dict = dict(size=1, origin=[0, 0, 0]), - points_color: Tuple[float] = (0.5, 0.5, 0.5), + points_color: Tuple[float] = (1, 1, 1), points_size: int = 2, mode: str = 'xyz') -> None: """Set the point cloud to draw. @@ -183,7 +193,7 @@ def set_points(self, visualization initialization. Defaults to dict(size=1, origin=[0, 0, 0]). points_color (Tuple[float]): The color of points. - Defaults to (0.5, 0.5, 0.5). + Defaults to (1, 1, 1). points_size (int): The size of points to show on visualizer. Defaults to 2. mode (str): Indicate type of the input points, available mode @@ -204,8 +214,10 @@ def set_points(self, self.o3d_vis.remove_geometry(self.pcd) # set points size in Open3D - if self.o3d_vis.get_render_option() is not None: - self.o3d_vis.get_render_option().point_size = points_size + render_option = self.o3d_vis.get_render_option() + if render_option is not None: + render_option.point_size = points_size + render_option.background_color = np.asarray([0, 0, 0]) points = points.copy() pcd = geometry.PointCloud() @@ -412,7 +424,8 @@ def draw_bev_bboxes(self, def draw_points_on_image(self, points: Union[np.ndarray, Tensor], pts2img: np.ndarray, - sizes: Union[np.ndarray, int] = 10) -> None: + sizes: Union[np.ndarray, int] = 3, + max_depth: Optional[float] = None) -> None: """Draw projected points on the image. Args: @@ -420,13 +433,18 @@ def draw_points_on_image(self, pts2img (np.ndarray): The transformation matrix from the coordinate of point cloud to image plane. sizes (np.ndarray or int): The marker size. Defaults to 10. + max_depth (float): The max depth in the color map. Defaults to + None. 
""" check_type('points', points, (np.ndarray, Tensor)) points = tensor2ndarray(points) assert self._image is not None, 'Please set image using `set_image`' projected_points = points_cam2img(points, pts2img, with_depth=True) depths = projected_points[:, 2] - colors = (depths % 20) / 20 + # Show depth adaptively consideing different scenes + if max_depth is None: + max_depth = depths.max() + colors = (depths % max_depth) / max_depth # use colormap to obtain the render color color_map = plt.get_cmap('jet') self.ax_save.scatter( @@ -435,7 +453,7 @@ def draw_points_on_image(self, c=colors, cmap=color_map, s=sizes, - alpha=0.5, + alpha=0.7, edgecolors='none') # TODO: set bbox color according to palette @@ -450,7 +468,8 @@ def draw_proj_bboxes_3d( line_widths: Union[int, float, List[Union[int, float]]] = 2, face_colors: Union[str, Tuple[int], List[Union[str, Tuple[int]]]] = 'royalblue', - alpha: Union[int, float] = 0.4): + alpha: Union[int, float] = 0.4, + img_size: Optional[Tuple] = None): """Draw projected 3D boxes on the image. Args: @@ -476,6 +495,7 @@ def draw_proj_bboxes_3d( face_colors (str or Tuple[int] or List[str or Tuple[int]]): The face colors. Defaults to 'royalblue'. alpha (int or float): The transparency of bboxes. Defaults to 0.4. + img_size (tuple, optional): The size (w, h) of the image. """ check_type('bboxes', bboxes_3d, BaseInstance3DBoxes) @@ -490,6 +510,14 @@ def draw_proj_bboxes_3d( raise NotImplementedError('unsupported box type!') corners_2d = proj_bbox3d_to_img(bboxes_3d, input_meta) + if img_size is not None: + # Filter out the bbox where half of stuff is outside the image. + # This is for the visualization of multi-view image. + valid_point_idx = (corners_2d[..., 0] >= 0) & \ + (corners_2d[..., 0] <= img_size[0]) & \ + (corners_2d[..., 1] >= 0) & (corners_2d[..., 1] <= img_size[1]) # noqa: E501 + valid_bbox_idx = valid_point_idx.sum(axis=-1) >= 4 + corners_2d = corners_2d[valid_bbox_idx] lines_verts_idx = [0, 1, 2, 3, 7, 6, 5, 4, 0, 3, 7, 4, 5, 1, 2, 6] lines_verts = corners_2d[:, lines_verts_idx, :] @@ -593,16 +621,59 @@ def _draw_instances_3d(self, if vis_task in ['mono_det', 'multi-modality_det']: assert 'img' in data_input img = data_input['img'] - if isinstance(data_input['img'], Tensor): - img = img.permute(1, 2, 0).numpy() - img = img[..., [2, 1, 0]] # bgr to rgb - self.set_image(img) - self.draw_proj_bboxes_3d(bboxes_3d, input_meta) - if vis_task == 'mono_det' and hasattr(instances, 'centers_2d'): - centers_2d = instances.centers_2d - self.draw_points(centers_2d) - drawn_img = self.get_image() - data_3d['img'] = drawn_img + if isinstance(img, list) or (isinstance(img, (np.ndarray, Tensor)) + and len(img.shape) == 4): + # show multi-view images + img_size = img[0].shape[:2] if isinstance( + img, list) else img.shape[-2:] # noqa: E501 + img_col = self.multi_imgs_col + img_row = math.ceil(len(img) / img_col) + composed_img = np.zeros( + (img_size[0] * img_row, img_size[1] * img_col, 3), + dtype=np.uint8) + for i, single_img in enumerate(img): + # Note that we should keep the same order of elements both + # in `img` and `input_meta` + if isinstance(single_img, Tensor): + single_img = single_img.permute(1, 2, 0).numpy() + single_img = single_img[..., [2, 1, 0]] # bgr to rgb + self.set_image(single_img) + single_img_meta = dict() + for key, meta in input_meta.items(): + if isinstance(meta, + (Sequence, np.ndarray, + Tensor)) and len(meta) == len(img): + single_img_meta[key] = meta[i] + else: + single_img_meta[key] = meta + self.draw_proj_bboxes_3d( + bboxes_3d, + 
single_img_meta, + img_size=single_img.shape[:2][::-1]) + if vis_task == 'mono_det' and hasattr( + instances, 'centers_2d'): + centers_2d = instances.centers_2d + self.draw_points(centers_2d) + composed_img[(i // img_col) * + img_size[0]:(i // img_col + 1) * img_size[0], + (i % img_col) * + img_size[1]:(i % img_col + 1) * + img_size[1]] = self.get_image() + data_3d['img'] = composed_img + else: + # show single-view image + # TODO: Solve the problem: some line segments of 3d bboxes are + # out of image by a large margin + if isinstance(data_input['img'], Tensor): + img = img.permute(1, 2, 0).numpy() + img = img[..., [2, 1, 0]] # bgr to rgb + self.set_image(img) + self.draw_proj_bboxes_3d(bboxes_3d, input_meta) + if vis_task == 'mono_det' and hasattr(instances, 'centers_2d'): + centers_2d = instances.centers_2d + self.draw_points(centers_2d) + drawn_img = self.get_image() + data_3d['img'] = drawn_img return data_3d @@ -644,7 +715,8 @@ def show(self, drawn_img: Optional[np.ndarray] = None, win_name: str = 'image', wait_time: int = 0, - continue_key: str = ' ') -> None: + continue_key: str = ' ', + vis_task: str = 'lidar_det') -> None: """Show the drawn point cloud/image. Args: @@ -661,22 +733,46 @@ def show(self, means "forever". Defaults to 0. continue_key (str): The key for users to continue. Defaults to ' '. """ + if vis_task == 'multi-modality_det': + img_wait_time = 0.5 + else: + img_wait_time = wait_time + + # In order to show multi-modal results at the same time, we show image + # firstly and then show point cloud since the running of + # Open3D will block the process + if hasattr(self, '_image'): + if drawn_img is None and drawn_img_3d is None: + # use the image got by Visualizer.get_image() + super().show(drawn_img_3d, win_name, img_wait_time, + continue_key) + else: + if drawn_img_3d is not None: + super().show(drawn_img_3d, win_name, img_wait_time, + continue_key) + if drawn_img is not None: + super().show(drawn_img, win_name, img_wait_time, + continue_key) + if hasattr(self, 'o3d_vis'): - self.o3d_vis.run() + self.o3d_vis.poll_events() + self.o3d_vis.update_renderer() + if wait_time > 0: + time.sleep(wait_time) + else: + self.o3d_vis.run() if save_path is not None: if not (save_path.endswith('.png') or save_path.endswith('.jpg')): save_path += '.png' self.o3d_vis.capture_screen_image(save_path) + + # TODO: support more flexible window control + self.o3d_vis.clear_geometries() self.o3d_vis.destroy_window() + self.o3d_vis.close() self._clear_o3d_vis() - if hasattr(self, '_image'): - if drawn_img_3d is not None: - super().show(drawn_img_3d, win_name, wait_time, continue_key) - if drawn_img is not None: - super().show(drawn_img, win_name, wait_time, continue_key) - # TODO: Support Visualize the 3D results from image and point cloud # respectively @master_only @@ -823,7 +919,8 @@ def add_datasample(self, drawn_img_3d, drawn_img, win_name=name, - wait_time=wait_time) + wait_time=wait_time, + vis_task=vis_task) if out_file is not None: # check the suffix of the name of image file diff --git a/projects/BEVFusion/README.md b/projects/BEVFusion/README.md index 101f7b35a1..ee4df1045c 100644 --- a/projects/BEVFusion/README.md +++ b/projects/BEVFusion/README.md @@ -29,6 +29,14 @@ We implement BEVFusion and provide the results and pretrained checkpoints on NuS python projects/BEVFusion/setup.py develop ``` +### Demo + +Run a demo on NuScenes data using [BEVFusion model](https://drive.google.com/file/d/1QkvbYDk4G2d6SZoeJqish13qSyXA4lp3/view?usp=share_link): + +```shell +python 
demo/multi_modality_demo.py demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__LIDAR_TOP__1532402927647951.pcd.bin demo/data/nuscenes/ demo/data/nuscenes/n015-2018-07-24-11-22-45+0800.pkl projects/BEVFusion/configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py ${CHECKPOINT_FILE} --cam-type all --score-thr 0.2 --show
+```
+
 ### Training commands
 
 In MMDetection3D's root directory, run the following command to train the model:
diff --git a/projects/BEVFusion/configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py b/projects/BEVFusion/configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py
index 8f12892372..16f6752ee8 100644
--- a/projects/BEVFusion/configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py
+++ b/projects/BEVFusion/configs/bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py
@@ -1,4 +1,4 @@
-_base_ = ['mmdet3d::_base_/default_runtime.py']
+_base_ = ['../../../configs/_base_/default_runtime.py']
 
 custom_imports = dict(
     imports=['projects.BEVFusion.bevfusion'], allow_failed_imports=False)
@@ -323,7 +323,7 @@
         meta_keys=[
             'cam2img', 'ori_cam2img', 'lidar2cam', 'lidar2img', 'cam2lidar',
             'ori_lidar2img', 'img_aug_matrix', 'box_type_3d', 'sample_idx',
-            'lidar_path', 'img_path'
+            'lidar_path', 'img_path', 'num_pts_feats', 'num_views'
         ])
 ]
diff --git a/tools/misc/browse_dataset.py b/tools/misc/browse_dataset.py
index 3381b516bd..164800952c 100644
--- a/tools/misc/browse_dataset.py
+++ b/tools/misc/browse_dataset.py
@@ -67,6 +67,9 @@ def build_data_cfg(config_path, aug, cfg_options):
     # use only first dataset for `ConcatDataset`
     if cfg.train_dataloader.dataset['type'] == 'ConcatDataset':
         cfg.train_dataloader.dataset = cfg.train_dataloader.dataset.datasets[0]
+    if cfg.train_dataloader.dataset['type'] == 'CBGSDataset':
+        cfg.train_dataloader.dataset = cfg.train_dataloader.dataset.dataset
+
     train_data_cfg = cfg.train_dataloader.dataset
 
     if aug:
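
The hunks above change `inference_multi_modality_detector` so that, with `cam_type='all'`, the image argument is a directory holding all camera views listed in the info file, and `result.img_path` comes back as a list that the updated `demo/multi_modality_demo.py` iterates over. A minimal Python sketch of driving that multi-view path directly, using the NuScenes demo assets added by this diff; the checkpoint path and device are placeholders, not part of the change:

```python
# Usage sketch only (not part of the diff): exercise the cam_type='all' path
# added in mmdet3d/apis/inference.py. Checkpoint path is a placeholder.
import mmcv

from mmdet3d.apis import inference_multi_modality_detector, init_model

config = ('projects/BEVFusion/configs/'
          'bevfusion_voxel0075_second_secfpn_8xb4-cyclic-20e_nus-3d.py')
checkpoint = 'checkpoints/bevfusion.pth'  # placeholder: downloaded BEVFusion weights
model = init_model(config, checkpoint, device='cuda:0')

# With cam_type='all', the image argument is the directory containing every
# camera view referenced by the info file, mirroring the demo command above.
result, data = inference_multi_modality_detector(
    model,
    'demo/data/nuscenes/n015-2018-07-24-11-22-45+0800__LIDAR_TOP__1532402927647951.pcd.bin',
    'demo/data/nuscenes/',
    'demo/data/nuscenes/n015-2018-07-24-11-22-45+0800.pkl',
    cam_type='all')

# result.img_path is now a list; load each view and convert BGR -> RGB,
# exactly as the updated demo script does before visualization.
imgs = [mmcv.imconvert(mmcv.imread(p), 'bgr', 'rgb') for p in result.img_path]
```

Single-view detectors keep the previous behaviour: pass one image file and a concrete camera view such as `CAM2` for KITTI or `CAM0` for SUN RGB-D, per the updated docstring.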