diff --git a/mmdet3d/datasets/transforms/compose.py b/mmdet3d/datasets/transforms/compose.py index 669bb84f8d..c35b4d1746 100644 --- a/mmdet3d/datasets/transforms/compose.py +++ b/mmdet3d/datasets/transforms/compose.py @@ -32,7 +32,7 @@ def __call__(self, data): data (dict): A result dict contains the data to transform. Returns: - dict: Transformed data. + dict: Transformed data. """ for t in self.transforms: diff --git a/mmdet3d/datasets/transforms/dbsampler.py b/mmdet3d/datasets/transforms/dbsampler.py index aaa24e22c8..a9753fcc0d 100644 --- a/mmdet3d/datasets/transforms/dbsampler.py +++ b/mmdet3d/datasets/transforms/dbsampler.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy import os -import warnings +from typing import List, Optional import mmengine import numpy as np @@ -16,18 +16,19 @@ class BatchSampler: Args: sample_list (list[dict]): List of samples. - name (str, optional): The category of samples. Default: None. - epoch (int, optional): Sampling epoch. Default: None. - shuffle (bool, optional): Whether to shuffle indices. Default: False. - drop_reminder (bool, optional): Drop reminder. Default: False. + name (str, optional): The category of samples. Defaults to None. + epoch (int, optional): Sampling epoch. Defaults to None. + shuffle (bool, optional): Whether to shuffle indices. + Defaults to False. + drop_reminder (bool, optional): Drop reminder. Defaults to False. """ def __init__(self, - sampled_list, - name=None, - epoch=None, - shuffle=True, - drop_reminder=False): + sampled_list: List[dict], + name: Optional[str] = None, + epoch: Optional[int] = None, + shuffle: bool = True, + drop_reminder: bool = False) -> None: self._sampled_list = sampled_list self._indices = np.arange(len(sampled_list)) if shuffle: @@ -40,7 +41,7 @@ def __init__(self, self._epoch_counter = 0 self._drop_reminder = drop_reminder - def _sample(self, num): + def _sample(self, num: int) -> List[int]: """Sample specific number of ground truths and return indices. Args: @@ -57,7 +58,7 @@ def _sample(self, num): self._idx += num return ret - def _reset(self): + def _reset(self) -> None: """Reset the index of batchsampler to zero.""" assert self._name is not None # print("reset", self._name) @@ -65,7 +66,7 @@ def _reset(self): np.random.shuffle(self._indices) self._idx = 0 - def sample(self, num): + def sample(self, num: int) -> List[dict]: """Sample specific number of ground truths. Args: @@ -88,24 +89,28 @@ class DataBaseSampler(object): rate (float): Rate of actual sampled over maximum sampled number. prepare (dict): Name of preparation functions and the input value. sample_groups (dict): Sampled classes and numbers. - classes (list[str], optional): List of classes. Default: None. - points_loader(dict, optional): Config of points loader. Default: - dict(type='LoadPointsFromFile', load_dim=4, use_dim=[0,1,2,3]) + classes (list[str], optional): List of classes. Defaults to None. + points_loader(dict, optional): Config of points loader. Defaults to + dict(type='LoadPointsFromFile', load_dim=4, use_dim=[0, 1, 2, 3]). + file_client_args (dict, optional): Config dict of file clients, + refer to + https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py + for more details. Defaults to dict(backend='disk'). """ def __init__(self, - info_path, - data_root, - rate, - prepare, - sample_groups, - classes=None, - points_loader=dict( + info_path: str, + data_root: str, + rate: float, + prepare: dict, + sample_groups: dict, + classes: Optional[List[str]] = None, + points_loader: dict = dict( type='LoadPointsFromFile', coord_type='LIDAR', load_dim=4, use_dim=[0, 1, 2, 3]), - file_client_args=dict(backend='disk')): + file_client_args: dict = dict(backend='disk')) -> None: super().__init__() self.data_root = data_root self.info_path = info_path @@ -118,18 +123,9 @@ def __init__(self, self.file_client = mmengine.FileClient(**file_client_args) # load data base infos - if hasattr(self.file_client, 'get_local_path'): - with self.file_client.get_local_path(info_path) as local_path: - # loading data from a file-like object needs file format - db_infos = mmengine.load( - open(local_path, 'rb'), file_format='pkl') - else: - warnings.warn( - 'The used MMCV version does not have get_local_path. ' - f'We treat the {info_path} as local paths and it ' - 'might cause errors if the path is not a local path. ' - 'Please use MMCV>= 1.3.16 if you meet errors.') - db_infos = mmengine.load(info_path) + with self.file_client.get_local_path(info_path) as local_path: + # loading data from a file-like object needs file format + db_infos = mmengine.load(open(local_path, 'rb'), file_format='pkl') # filter database infos from mmengine.logging import MMLogger @@ -163,7 +159,7 @@ def __init__(self, # TODO: No group_sampling currently @staticmethod - def filter_by_difficulty(db_infos, removed_difficulty): + def filter_by_difficulty(db_infos: dict, removed_difficulty: list) -> dict: """Filter ground truths by difficulties. Args: @@ -182,7 +178,7 @@ def filter_by_difficulty(db_infos, removed_difficulty): return new_db_infos @staticmethod - def filter_by_min_points(db_infos, min_gt_points_dict): + def filter_by_min_points(db_infos: dict, min_gt_points_dict: dict) -> dict: """Filter ground truths by number of points in the bbox. Args: @@ -203,12 +199,19 @@ def filter_by_min_points(db_infos, min_gt_points_dict): db_infos[name] = filtered_infos return db_infos - def sample_all(self, gt_bboxes, gt_labels, img=None, ground_plane=None): + def sample_all(self, + gt_bboxes: np.ndarray, + gt_labels: np.ndarray, + img: Optional[np.ndarray] = None, + ground_plane: Optional[np.ndarray] = None) -> dict: """Sampling all categories of bboxes. Args: gt_bboxes (np.ndarray): Ground truth bounding boxes. gt_labels (np.ndarray): Ground truth labels of boxes. + img (np.ndarray, optional): Image array. Defaults to None. + ground_plane (np.ndarray, optional): Ground plane information. + Defaults to None. Returns: dict: Dict of sampled 'pseudo ground truths'. @@ -301,7 +304,10 @@ def sample_all(self, gt_bboxes, gt_labels, img=None, ground_plane=None): return ret - def sample_class_v2(self, name, num, gt_bboxes): + def sample_class_v2(self, + name: str, + num: int, + gt_bboxes: np.ndarray) -> List[dict]: """Sampling specific categories of bounding boxes. Args: diff --git a/mmdet3d/datasets/transforms/formating.py b/mmdet3d/datasets/transforms/formating.py index 7e776d4aaf..6ac8014ce1 100644 --- a/mmdet3d/datasets/transforms/formating.py +++ b/mmdet3d/datasets/transforms/formating.py @@ -63,16 +63,16 @@ class Pack3DDetInputs(BaseTransform): def __init__( self, - keys: dict, - meta_keys: dict = ('img_path', 'ori_shape', 'img_shape', 'lidar2img', - 'depth2img', 'cam2img', 'pad_shape', 'scale_factor', - 'flip', 'pcd_horizontal_flip', 'pcd_vertical_flip', - 'box_mode_3d', 'box_type_3d', 'img_norm_cfg', - 'num_pts_feats', 'pcd_trans', 'sample_idx', - 'pcd_scale_factor', 'pcd_rotation', - 'pcd_rotation_angle', 'lidar_path', - 'transformation_3d_flow', 'trans_mat', - 'affine_aug')): + keys: tuple, + meta_keys: tuple = ('img_path', 'ori_shape', 'img_shape', 'lidar2img', + 'depth2img', 'cam2img', 'pad_shape', + 'scale_factor', 'flip', 'pcd_horizontal_flip', + 'pcd_vertical_flip', 'box_mode_3d', 'box_type_3d', + 'img_norm_cfg', 'num_pts_feats', 'pcd_trans', + 'sample_idx', 'pcd_scale_factor', 'pcd_rotation', + 'pcd_rotation_angle', 'lidar_path', + 'transformation_3d_flow', 'trans_mat', + 'affine_aug')) -> None: self.keys = keys self.meta_keys = meta_keys @@ -99,7 +99,7 @@ def transform(self, results: Union[dict, - img - 'data_samples' (obj:`Det3DDataSample`): The annotation info of - the sample. + the sample. """ # augtest if isinstance(results, list): @@ -116,7 +116,7 @@ def transform(self, results: Union[dict, else: raise NotImplementedError - def pack_single_results(self, results): + def pack_single_results(self, results: dict) -> dict: """Method to pack the single input data. when the value in this dict is a list, it usually is in Augmentations Testing. @@ -132,7 +132,7 @@ def pack_single_results(self, results): - points - img - - 'data_samples' (obj:`Det3DDataSample`): The annotation info + - 'data_samples' (:obj:`Det3DDataSample`): The annotation info of the sample. """ # Format 3D data @@ -220,6 +220,7 @@ def pack_single_results(self, results): return packed_results def __repr__(self) -> str: + """str: Return a string that describes the module.""" repr_str = self.__class__.__name__ repr_str += f'(keys={self.keys})' repr_str += f'(meta_keys={self.meta_keys})' diff --git a/mmdet3d/datasets/transforms/loading.py b/mmdet3d/datasets/transforms/loading.py index 615c1d74f3..15e2d8af8d 100644 --- a/mmdet3d/datasets/transforms/loading.py +++ b/mmdet3d/datasets/transforms/loading.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List +from typing import List, Union import mmcv import mmengine @@ -13,7 +13,7 @@ @TRANSFORMS.register_module() -class LoadMultiViewImageFromFiles(object): +class LoadMultiViewImageFromFiles(BaseTransform): """Load multi channel images from a list of separate channel files. Expects results['img_filename'] to be a list of filenames. @@ -25,11 +25,15 @@ class LoadMultiViewImageFromFiles(object): Defaults to 'unchanged'. """ - def __init__(self, to_float32=False, color_type='unchanged'): + def __init__( + self, + to_float32: bool = False, + color_type: str = 'unchanged' + ) -> None: self.to_float32 = to_float32 self.color_type = color_type - def __call__(self, results): + def transform(self, results: dict) -> dict: """Call function to load multi-view image from files. Args: @@ -139,7 +143,7 @@ class LoadPointsFromMultiSweeps(BaseTransform): Defaults to [0, 1, 2, 4]. file_client_args (dict, optional): Config dict of file clients, refer to - https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py + https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py for more details. Defaults to dict(backend='disk'). pad_empty_sweeps (bool, optional): Whether to repeat keyframe when sweeps is empty. Defaults to False. @@ -150,14 +154,16 @@ class LoadPointsFromMultiSweeps(BaseTransform): Defaults to False. """ - def __init__(self, - sweeps_num=10, - load_dim=5, - use_dim=[0, 1, 2, 4], - file_client_args=dict(backend='disk'), - pad_empty_sweeps=False, - remove_close=False, - test_mode=False): + def __init__( + self, + sweeps_num: int = 10, + load_dim: int = 5, + use_dim: List[int] = [0, 1, 2, 4], + file_client_args: dict = dict(backend='disk'), + pad_empty_sweeps: bool = False, + remove_close: bool = False, + test_mode: bool = False + ) -> None: self.load_dim = load_dim self.sweeps_num = sweeps_num self.use_dim = use_dim @@ -167,7 +173,7 @@ def __init__(self, self.remove_close = remove_close self.test_mode = test_mode - def _load_points(self, pts_filename): + def _load_points(self, pts_filename: str) -> np.ndarray: """Private function to load point clouds data. Args: @@ -189,7 +195,11 @@ def _load_points(self, pts_filename): points = np.fromfile(pts_filename, dtype=np.float32) return points - def _remove_close(self, points, radius=1.0): + def _remove_close( + self, + points: Union[np.ndarray, BasePoints], + radius: float = 1.0 + ) -> Union[np.ndarray, BasePoints]: """Removes point too close within a certain radius from origin. Args: @@ -198,7 +208,7 @@ def _remove_close(self, points, radius=1.0): Defaults to 1.0. Returns: - np.ndarray: Points after removing. + np.ndarray | :obj:`BasePoints`: Points after removing. """ if isinstance(points, np.ndarray): points_numpy = points @@ -211,7 +221,7 @@ def _remove_close(self, points, radius=1.0): not_close = np.logical_not(np.logical_and(x_filt, y_filt)) return points[not_close] - def transform(self, results): + def transform(self, results: dict) -> dict: """Call function to load multi-sweep point clouds from files. Args: @@ -220,7 +230,7 @@ def transform(self, results): Returns: dict: The result dict containing the multi-sweep points data. - Added key and value are described below. + Updated key and value are described below. - points (np.ndarray | :obj:`BasePoints`): Multi-sweep point cloud arrays. @@ -290,7 +300,7 @@ class PointSegClassMapping(BaseTransform): others as len(valid_cat_ids). """ - def transform(self, results: dict) -> None: + def transform(self, results: dict) -> dict: """Call function to map original semantic class to valid category ids. Args: @@ -322,8 +332,6 @@ def transform(self, results: dict) -> None: def __repr__(self): """str: Return a string that describes the module.""" repr_str = self.__class__.__name__ - repr_str += f'(valid_cat_ids={self.valid_cat_ids}, ' - repr_str += f'max_cat_id={self.max_cat_id})' return repr_str @@ -385,13 +393,14 @@ class LoadPointsFromFile(BaseTransform): Args: coord_type (str): The type of coordinates of points cloud. Available options includes: + - 'LIDAR': Points in LiDAR coordinates. - 'DEPTH': Points in depth coordinates, usually for indoor dataset. - 'CAMERA': Points in camera coordinates. load_dim (int, optional): The dimension of the loaded points. Defaults to 6. - use_dim (list[int], optional): Which dimensions of the points to use. - Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4 + use_dim (list[int] | int, optional): Which dimensions of the points + to use. Defaults to [0, 1, 2]. For KITTI dataset, set use_dim=4 or use_dim=[0, 1, 2, 3] to use the intensity dimension. shift_height (bool, optional): Whether to use shifted height. Defaults to False. @@ -399,7 +408,7 @@ class LoadPointsFromFile(BaseTransform): Defaults to False. file_client_args (dict, optional): Config dict of file clients, refer to - https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py + https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py for more details. Defaults to dict(backend='disk'). """ @@ -407,7 +416,7 @@ def __init__( self, coord_type: str, load_dim: int = 6, - use_dim: list = [0, 1, 2], + use_dim: Union[int, List[int]] = [0, 1, 2], shift_height: bool = False, use_color: bool = False, file_client_args: dict = dict(backend='disk') @@ -523,6 +532,7 @@ class LoadAnnotations3D(LoadAnnotations): Required Keys: - ann_info (dict) + - gt_bboxes_3d (:obj:`LiDARInstance3DBoxes` | :obj:`DepthInstance3DBoxes` | :obj:`CameraInstance3DBoxes`): 3D ground truth bboxes. Only when `with_bbox_3d` is True @@ -592,7 +602,7 @@ class LoadAnnotations3D(LoadAnnotations): seg_3d_dtype (dtype, optional): Dtype of 3D semantic masks. Defaults to int64. file_client_args (dict): Config dict of file clients, refer to - https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py + https://github.com/open-mmlab/mmengine/blob/main/mmengine/fileio/file_client.py for more details. """ diff --git a/mmdet3d/datasets/transforms/test_time_aug.py b/mmdet3d/datasets/transforms/test_time_aug.py index 55c3c94069..97a46ba1f2 100644 --- a/mmdet3d/datasets/transforms/test_time_aug.py +++ b/mmdet3d/datasets/transforms/test_time_aug.py @@ -16,7 +16,7 @@ class MultiScaleFlipAug3D(BaseTransform): Args: transforms (list[dict]): Transforms to apply in each augmentation. - img_scale (tuple | list[tuple]: Images scales for resizing. + img_scale (tuple | list[tuple]): Images scales for resizing. pts_scale_ratio (float | list[float]): Points scale ratios for resizing. flip (bool, optional): Whether apply flip augmentation. @@ -25,11 +25,11 @@ class MultiScaleFlipAug3D(BaseTransform): directions for images, options are "horizontal" and "vertical". If flip_direction is list, multiple flip augmentations will be applied. It has no effect when ``flip == False``. - Defaults to "horizontal". - pcd_horizontal_flip (bool, optional): Whether apply horizontal + Defaults to 'horizontal'. + pcd_horizontal_flip (bool, optional): Whether to apply horizontal flip augmentation to point cloud. Defaults to True. Note that it works only when 'flip' is turned on. - pcd_vertical_flip (bool, optional): Whether apply vertical flip + pcd_vertical_flip (bool, optional): Whether to apply vertical flip augmentation to point cloud. Defaults to True. Note that it works only when 'flip' is turned on. """ diff --git a/mmdet3d/datasets/transforms/transforms_3d.py b/mmdet3d/datasets/transforms/transforms_3d.py index 8a08267c3b..56e6601121 100644 --- a/mmdet3d/datasets/transforms/transforms_3d.py +++ b/mmdet3d/datasets/transforms/transforms_3d.py @@ -1,7 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. import random import warnings -from typing import Dict, List, Optional, Tuple, Union +from typing import List, Optional, Tuple, Union import cv2 import numpy as np @@ -76,7 +76,6 @@ class RandomFlip3D(RandomFlip): otherwise it will be randomly decided by a ratio specified in the init method. - Required Keys: - points (np.float32) @@ -329,7 +328,7 @@ class ObjectSample(BaseTransform): def __init__(self, db_sampler: dict, sample_2d: bool = False, - use_ground_plane: bool = False): + use_ground_plane: bool = False) -> None: self.sampler_cfg = db_sampler self.sample_2d = sample_2d if 'type' not in db_sampler.keys(): @@ -456,10 +455,10 @@ class ObjectNoise(BaseTransform): """ def __init__(self, - translation_std: list = [0.25, 0.25, 0.25], - global_rot_range: list = [0.0, 0.0], - rot_range: list = [-0.15707963267, 0.15707963267], - num_try: int = 100): + translation_std: List[float] = [0.25, 0.25, 0.25], + global_rot_range: List[float] = [0.0, 0.0], + rot_range: List[float] = [-0.15707963267, 0.15707963267], + num_try: int = 100) -> None: self.translation_std = translation_std self.global_rot_range = global_rot_range self.rot_range = rot_range @@ -522,7 +521,7 @@ class GlobalAlignment(BaseTransform): def __init__(self, rotation_axis: int) -> None: self.rotation_axis = rotation_axis - def _trans_points(self, results: Dict, trans_factor: np.ndarray) -> None: + def _trans_points(self, results: dict, trans_factor: np.ndarray) -> None: """Private function to translate points. Args: @@ -534,7 +533,7 @@ def _trans_points(self, results: Dict, trans_factor: np.ndarray) -> None: """ results['points'].translate(trans_factor) - def _rot_points(self, results: Dict, rot_mat: np.ndarray) -> None: + def _rot_points(self, results: dict, rot_mat: np.ndarray) -> None: """Private function to rotate bounding boxes and points. Args: @@ -560,7 +559,7 @@ def _check_rot_mat(self, rot_mat: np.ndarray) -> None: is_valid &= (rot_mat[:, self.rotation_axis] == valid_array).all() assert is_valid, f'invalid rotation matrix {rot_mat}' - def transform(self, results: Dict) -> Dict: + def transform(self, results: dict) -> dict: """Call function to shuffle points. Args: @@ -586,6 +585,7 @@ def transform(self, results: Dict) -> Dict: return results def __repr__(self): + """str: Return a string that describes the module.""" repr_str = self.__class__.__name__ repr_str += f'(rotation_axis={self.rotation_axis})' return repr_str @@ -804,6 +804,7 @@ def transform(self, input_dict: dict) -> dict: return input_dict def __repr__(self): + """str: Return a string that describes the module.""" return self.__class__.__name__ @@ -823,7 +824,7 @@ class ObjectRangeFilter(BaseTransform): point_cloud_range (list[float]): Point cloud range. """ - def __init__(self, point_cloud_range: list): + def __init__(self, point_cloud_range: List[float]): self.pcd_range = np.array(point_cloud_range, dtype=np.float32) def transform(self, input_dict: dict) -> dict: @@ -885,7 +886,7 @@ class PointsRangeFilter(BaseTransform): point_cloud_range (list[float]): Point cloud range. """ - def __init__(self, point_cloud_range: list): + def __init__(self, point_cloud_range: List[float]) -> None: self.pcd_range = np.array(point_cloud_range, dtype=np.float32) def transform(self, input_dict: dict) -> dict: @@ -938,7 +939,7 @@ class ObjectNameFilter(BaseTransform): classes (list[str]): List of class names to be kept for training. """ - def __init__(self, classes: list): + def __init__(self, classes: List[str]) -> None: self.classes = classes self.labels = list(range(len(self.classes))) @@ -996,34 +997,38 @@ class PointSample(BaseTransform): def __init__(self, num_points: int, - sample_range: float = None, - replace: bool = False): + sample_range: Optional[float] = None, + replace: bool = False) -> None: self.num_points = num_points self.sample_range = sample_range self.replace = replace - def _points_random_sampling(self, - points, - num_samples, - sample_range=None, - replace=False, - return_choices=False): + def _points_random_sampling( + self, + points: BasePoints, + num_samples: int, + sample_range: Optional[float] = None, + replace: bool = False, + return_choices: bool = False + ) -> Union[Tuple[BasePoints, np.ndarray], BasePoints]: """Points random sampling. Sample points to a certain number. Args: - points (np.ndarray | :obj:`BasePoints`): 3D Points. + points (:obj:`BasePoints`): 3D Points. num_samples (int): Number of samples to be sampled. sample_range (float, optional): Indicating the range where the points will be sampled. Defaults to None. replace (bool, optional): Sampling with or without replacement. - Defaults to None. + Defaults to False. return_choices (bool, optional): Whether return choice. Defaults to False. + Returns: - tuple[np.ndarray] | np.ndarray: - - points (np.ndarray | :obj:`BasePoints`): 3D Points. + tuple[:obj:`BasePoints`, np.ndarray] | :obj:`BasePoints`: + + - points (:obj:`BasePoints`): 3D Points. - choices (np.ndarray, optional): The generated random samples. """ if not replace: @@ -1031,7 +1036,7 @@ def _points_random_sampling(self, point_range = range(len(points)) if sample_range is not None and not replace: # Only sampling the near points when len(points) >= num_samples - dist = np.linalg.norm(points.tensor, axis=1) + dist = np.linalg.norm(points.coord.numpy(), axis=1) far_inds = np.where(dist >= sample_range)[0] near_inds = np.where(dist < sample_range)[0] # in case there are too many far points @@ -1055,6 +1060,7 @@ def transform(self, input_dict: dict) -> dict: Args: input_dict (dict): Result dict from loading pipeline. + Returns: dict: Results after sampling, 'points', 'pts_instance_mask' and 'pts_semantic_mask' keys are updated in the result dict. @@ -1214,8 +1220,11 @@ def _input_generation(self, coords: np.ndarray, patch_center: np.ndarray, return points - def _patch_points_sampling(self, points: BasePoints, - sem_mask: np.ndarray) -> BasePoints: + def _patch_points_sampling( + self, + points: BasePoints, + sem_mask: np.ndarray + ) -> Tuple[BasePoints, np.ndarray]: """Patch points sampling. First sample a valid patch. @@ -1226,7 +1235,7 @@ def _patch_points_sampling(self, points: BasePoints, sem_mask (np.ndarray): semantic segmentation mask for input points. Returns: - tuple[:obj:`BasePoints`, np.ndarray] | :obj:`BasePoints`: + tuple[:obj:`BasePoints`, np.ndarray]: - points (:obj:`BasePoints`): 3D Points. - choices (np.ndarray): The generated random samples. @@ -1433,7 +1442,7 @@ def __repr__(self): @TRANSFORMS.register_module() -class VoxelBasedPointSampler(object): +class VoxelBasedPointSampler(BaseTransform): """Voxel based point sampler. Apply voxel sampling to multiple sweep points. @@ -1445,7 +1454,10 @@ class VoxelBasedPointSampler(object): for input points. """ - def __init__(self, cur_sweep_cfg, prev_sweep_cfg=None, time_dim=3): + def __init__(self, + cur_sweep_cfg: dict, + prev_sweep_cfg: Optional[dict] = None, + time_dim: int = 3) -> None: self.cur_voxel_generator = VoxelGenerator(**cur_sweep_cfg) self.cur_voxel_num = self.cur_voxel_generator._max_voxels self.time_dim = time_dim @@ -1458,7 +1470,10 @@ def __init__(self, cur_sweep_cfg, prev_sweep_cfg=None, time_dim=3): self.prev_voxel_generator = None self.prev_voxel_num = 0 - def _sample_points(self, points, sampler, point_dim): + def _sample_points(self, + points: np.ndarray, + sampler: VoxelGenerator, + point_dim: int) -> np.ndarray: """Sample points for each points subset. Args: @@ -1484,7 +1499,7 @@ def _sample_points(self, points, sampler, point_dim): return sample_points - def __call__(self, results): + def transform(self, results: dict) -> dict: """Call function to sample points from multiple sweeps. Args: @@ -1766,6 +1781,7 @@ def _get_ref_point(self, ref_point1: np.ndarray, return ref_point3 def __repr__(self): + """str: Return a string that describes the module.""" repr_str = self.__class__.__name__ repr_str += f'(img_scale={self.img_scale}, ' repr_str += f'down_ratio={self.down_ratio}) ' @@ -1786,7 +1802,7 @@ class RandomShiftScale(BaseTransform): aug_prob (float): The shifting and scaling probability. """ - def __init__(self, shift_scale: Tuple[float], aug_prob: float): + def __init__(self, shift_scale: Tuple[float], aug_prob: float) -> None: self.shift_scale = shift_scale self.aug_prob = aug_prob @@ -1825,6 +1841,7 @@ def transform(self, results: dict) -> dict: return results def __repr__(self): + """str: Return a string that describes the module.""" repr_str = self.__class__.__name__ repr_str += f'(shift_scale={self.shift_scale}, ' repr_str += f'aug_prob={self.aug_prob}) '