Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Xiangxu-0103 committed Mar 8, 2023
1 parent cd4b6b2 commit 9629d0d
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 307 deletions.
3 changes: 1 addition & 2 deletions mmdet3d/apis/inferencers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base_det3d_inferencer import BaseDet3DInferencer
from .base_seg3d_inferencer import BaseSeg3DInferencer
from .lidar_det3d_inferencer import LidarDet3DInferencer
from .lidar_seg3d_inferencer import LidarSeg3DInferencer
from .mono_det3d_inferencer import MonoDet3DInferencer

__all__ = [
'BaseDet3DInferencer', 'MonoDet3DInferencer', 'LidarDet3DInferencer',
'BaseSeg3DInferencer', 'LidarSeg3DInferencer'
'LidarSeg3DInferencer'
]
19 changes: 13 additions & 6 deletions mmdet3d/apis/inferencers/base_det3d_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,11 +295,18 @@ def pred2dict(self, data_sample: InstanceData) -> Dict:
It's better to contain only basic data elements such as strings and
numbers in order to guarantee it's json-serializable.
"""
pred_instances = data_sample.pred_instances_3d.numpy()
result = {
'bboxes_3d': pred_instances.bboxes_3d.tensor.cpu().tolist(),
'labels_3d': pred_instances.labels_3d.tolist(),
'scores_3d': pred_instances.scores_3d.tolist()
}
result = {}
if 'pred_instances_3d' in data_sample:
pred_instances_3d = data_sample.pred_instances_3d.numpy()
result = {
'bboxes_3d': pred_instances_3d.bboxes_3d.tensor.cpu().tolist(),
'labels_3d': pred_instances_3d.labels_3d.tolist(),
'scores_3d': pred_instances_3d.scores_3d.tolist()
}

if 'pred_pts_seg' in data_sample:
pred_pts_seg = data_sample.pred_pts_seg.numpy()
result['pts_semantic_mask'] = \
pred_pts_seg.pts_semantic_mask.tolist()

return result
296 changes: 0 additions & 296 deletions mmdet3d/apis/inferencers/base_seg3d_inferencer.py

This file was deleted.

7 changes: 4 additions & 3 deletions mmdet3d/apis/inferencers/lidar_seg3d_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from mmdet3d.registry import INFERENCERS
from mmdet3d.utils import ConfigType
from .base_seg3d_inferencer import BaseSeg3DInferencer
from .base_det3d_inferencer import BaseDet3DInferencer

InstanceList = List[InstanceData]
InputType = Union[str, np.ndarray]
Expand All @@ -22,7 +22,7 @@

@INFERENCERS.register_module(name='seg3d-lidar')
@INFERENCERS.register_module()
class LidarSeg3DInferencer(BaseSeg3DInferencer):
class LidarSeg3DInferencer(BaseDet3DInferencer):
"""The inferencer of LiDAR-based segmentation.
Args:
Expand All @@ -46,7 +46,8 @@ class LidarSeg3DInferencer(BaseSeg3DInferencer):
preprocess_kwargs: set = set()
forward_kwargs: set = set()
visualize_kwargs: set = {
'return_vis', 'show', 'wait_time', 'draw_pred', 'img_out_dir'
'return_vis', 'show', 'wait_time', 'draw_pred', 'pred_score_thr',
'img_out_dir'
}
postprocess_kwargs: set = {
'print_result', 'pred_out_file', 'return_datasample'
Expand Down

0 comments on commit 9629d0d

Please sign in to comment.