Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance]: Add typehints for dataset transforms and fix potential bug for PointSample #1875

Merged
merged 6 commits into from
Oct 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mmdet3d/datasets/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def __call__(self, data):
data (dict): A result dict contains the data to transform.

Returns:
dict: Transformed data.
dict: Transformed data.
"""

for t in self.transforms:
Expand Down
86 changes: 46 additions & 40 deletions mmdet3d/datasets/transforms/dbsampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import os
import warnings
from typing import List, Optional

import mmengine
import numpy as np
Expand All @@ -16,18 +16,19 @@ class BatchSampler:

Args:
sample_list (list[dict]): List of samples.
name (str, optional): The category of samples. Default: None.
epoch (int, optional): Sampling epoch. Default: None.
shuffle (bool, optional): Whether to shuffle indices. Default: False.
drop_reminder (bool, optional): Drop reminder. Default: False.
name (str, optional): The category of samples. Defaults to None.
epoch (int, optional): Sampling epoch. Defaults to None.
shuffle (bool, optional): Whether to shuffle indices.
Defaults to False.
drop_reminder (bool, optional): Drop reminder. Defaults to False.
"""

def __init__(self,
sampled_list,
name=None,
epoch=None,
shuffle=True,
drop_reminder=False):
sampled_list: List[dict],
name: Optional[str] = None,
epoch: Optional[int] = None,
shuffle: bool = True,
drop_reminder: bool = False) -> None:
self._sampled_list = sampled_list
self._indices = np.arange(len(sampled_list))
if shuffle:
Expand All @@ -40,7 +41,7 @@ def __init__(self,
self._epoch_counter = 0
self._drop_reminder = drop_reminder

def _sample(self, num):
def _sample(self, num: int) -> List[int]:
"""Sample specific number of ground truths and return indices.

Args:
Expand All @@ -57,15 +58,15 @@ def _sample(self, num):
self._idx += num
return ret

def _reset(self):
def _reset(self) -> None:
"""Reset the index of batchsampler to zero."""
assert self._name is not None
# print("reset", self._name)
if self._shuffle:
np.random.shuffle(self._indices)
self._idx = 0

def sample(self, num):
def sample(self, num: int) -> List[dict]:
"""Sample specific number of ground truths.

Args:
Expand All @@ -88,24 +89,28 @@ class DataBaseSampler(object):
rate (float): Rate of actual sampled over maximum sampled number.
prepare (dict): Name of preparation functions and the input value.
sample_groups (dict): Sampled classes and numbers.
classes (list[str], optional): List of classes. Default: None.
points_loader(dict, optional): Config of points loader. Default:
dict(type='LoadPointsFromFile', load_dim=4, use_dim=[0,1,2,3])
classes (list[str], optional): List of classes. Defaults to None.
points_loader(dict, optional): Config of points loader. Defaults to
dict(type='LoadPointsFromFile', load_dim=4, use_dim=[0, 1, 2, 3]).
file_client_args (dict, optional): Config dict of file clients,
refer to
https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py
for more details. Defaults to dict(backend='disk').
"""

def __init__(self,
info_path,
data_root,
rate,
prepare,
sample_groups,
classes=None,
points_loader=dict(
info_path: str,
data_root: str,
rate: float,
prepare: dict,
sample_groups: dict,
classes: Optional[List[str]] = None,
points_loader: dict = dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=4,
use_dim=[0, 1, 2, 3]),
file_client_args=dict(backend='disk')):
file_client_args: dict = dict(backend='disk')) -> None:
super().__init__()
self.data_root = data_root
self.info_path = info_path
Expand All @@ -118,18 +123,9 @@ def __init__(self,
self.file_client = mmengine.FileClient(**file_client_args)

# load data base infos
if hasattr(self.file_client, 'get_local_path'):
with self.file_client.get_local_path(info_path) as local_path:
# loading data from a file-like object needs file format
db_infos = mmengine.load(
open(local_path, 'rb'), file_format='pkl')
else:
warnings.warn(
'The used MMCV version does not have get_local_path. '
f'We treat the {info_path} as local paths and it '
'might cause errors if the path is not a local path. '
'Please use MMCV>= 1.3.16 if you meet errors.')
db_infos = mmengine.load(info_path)
with self.file_client.get_local_path(info_path) as local_path:
# loading data from a file-like object needs file format
db_infos = mmengine.load(open(local_path, 'rb'), file_format='pkl')

# filter database infos
from mmengine.logging import MMLogger
Expand Down Expand Up @@ -163,7 +159,7 @@ def __init__(self,
# TODO: No group_sampling currently

@staticmethod
def filter_by_difficulty(db_infos, removed_difficulty):
def filter_by_difficulty(db_infos: dict, removed_difficulty: list) -> dict:
"""Filter ground truths by difficulties.

Args:
Expand All @@ -182,7 +178,7 @@ def filter_by_difficulty(db_infos, removed_difficulty):
return new_db_infos

@staticmethod
def filter_by_min_points(db_infos, min_gt_points_dict):
def filter_by_min_points(db_infos: dict, min_gt_points_dict: dict) -> dict:
"""Filter ground truths by number of points in the bbox.

Args:
Expand All @@ -203,12 +199,19 @@ def filter_by_min_points(db_infos, min_gt_points_dict):
db_infos[name] = filtered_infos
return db_infos

def sample_all(self, gt_bboxes, gt_labels, img=None, ground_plane=None):
def sample_all(self,
gt_bboxes: np.ndarray,
gt_labels: np.ndarray,
img: Optional[np.ndarray] = None,
ground_plane: Optional[np.ndarray] = None) -> dict:
"""Sampling all categories of bboxes.

Args:
gt_bboxes (np.ndarray): Ground truth bounding boxes.
gt_labels (np.ndarray): Ground truth labels of boxes.
img (np.ndarray, optional): Image array. Defaults to None.
ground_plane (np.ndarray, optional): Ground plane information.
Defaults to None.

Returns:
dict: Dict of sampled 'pseudo ground truths'.
Expand Down Expand Up @@ -301,7 +304,10 @@ def sample_all(self, gt_bboxes, gt_labels, img=None, ground_plane=None):

return ret

def sample_class_v2(self, name, num, gt_bboxes):
def sample_class_v2(self,
name: str,
num: int,
gt_bboxes: np.ndarray) -> List[dict]:
"""Sampling specific categories of bounding boxes.

Args:
Expand Down
27 changes: 14 additions & 13 deletions mmdet3d/datasets/transforms/formating.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,16 @@ class Pack3DDetInputs(BaseTransform):

def __init__(
self,
keys: dict,
meta_keys: dict = ('img_path', 'ori_shape', 'img_shape', 'lidar2img',
'depth2img', 'cam2img', 'pad_shape', 'scale_factor',
'flip', 'pcd_horizontal_flip', 'pcd_vertical_flip',
'box_mode_3d', 'box_type_3d', 'img_norm_cfg',
'num_pts_feats', 'pcd_trans', 'sample_idx',
'pcd_scale_factor', 'pcd_rotation',
'pcd_rotation_angle', 'lidar_path',
'transformation_3d_flow', 'trans_mat',
'affine_aug')):
keys: tuple,
meta_keys: tuple = ('img_path', 'ori_shape', 'img_shape', 'lidar2img',
'depth2img', 'cam2img', 'pad_shape',
'scale_factor', 'flip', 'pcd_horizontal_flip',
'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d',
'img_norm_cfg', 'num_pts_feats', 'pcd_trans',
'sample_idx', 'pcd_scale_factor', 'pcd_rotation',
'pcd_rotation_angle', 'lidar_path',
'transformation_3d_flow', 'trans_mat',
'affine_aug')) -> None:
self.keys = keys
self.meta_keys = meta_keys

Expand All @@ -99,7 +99,7 @@ def transform(self, results: Union[dict,
- img

- 'data_samples' (obj:`Det3DDataSample`): The annotation info of
the sample.
the sample.
"""
# augtest
if isinstance(results, list):
Expand All @@ -116,7 +116,7 @@ def transform(self, results: Union[dict,
else:
raise NotImplementedError

def pack_single_results(self, results):
def pack_single_results(self, results: dict) -> dict:
"""Method to pack the single input data. when the value in this dict is
a list, it usually is in Augmentations Testing.

Expand All @@ -132,7 +132,7 @@ def pack_single_results(self, results):
- points
- img

- 'data_samples' (obj:`Det3DDataSample`): The annotation info
- 'data_samples' (:obj:`Det3DDataSample`): The annotation info
of the sample.
"""
# Format 3D data
Expand Down Expand Up @@ -220,6 +220,7 @@ def pack_single_results(self, results):
return packed_results

def __repr__(self) -> str:
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(keys={self.keys})'
repr_str += f'(meta_keys={self.meta_keys})'
Expand Down
Loading