Skip to content

Commit

Permalink
[Enhance] Sampling points based on distance metric (#667)
Browse files Browse the repository at this point in the history
* [Enhance] Sampling points based on distance metric

* fix typo

* refine unittest

* refine unittest

* refine details & add unittest & refine configs

* remove __repr__ & rename arg

* fix unittest

* add unitest

* refine unittest

* refine code

* refine code

* refine depth calculation

* refine code
  • Loading branch information
wHao-Wu authored Aug 5, 2021
1 parent d3213cd commit fc9e0d9
Show file tree
Hide file tree
Showing 15 changed files with 157 additions and 61 deletions.
4 changes: 2 additions & 2 deletions configs/3dssd/3dssd_4x4_kitti-3d-car.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
scale_ratio_range=[0.9, 1.1]),
# 3DSSD can get a higher performance without this transform
# dict(type='BackgroundPointsFilter', bbox_enlarge_range=(0.5, 2.0, 0.5)),
dict(type='IndoorPointSample', num_points=16384),
dict(type='PointSample', num_points=16384),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
Expand All @@ -78,7 +78,7 @@
dict(type='RandomFlip3D'),
dict(
type='PointsRangeFilter', point_cloud_range=point_cloud_range),
dict(type='IndoorPointSample', num_points=16384),
dict(type='PointSample', num_points=16384),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
Expand Down
4 changes: 2 additions & 2 deletions configs/_base_/datasets/sunrgbd-3d-10class.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
rot_range=[-0.523599, 0.523599],
scale_ratio_range=[0.85, 1.15],
shift_height=True),
dict(type='IndoorPointSample', num_points=20000),
dict(type='PointSample', num_points=20000),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(type='Collect3D', keys=['points', 'gt_bboxes_3d', 'gt_labels_3d'])
]
Expand All @@ -47,7 +47,7 @@
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
),
dict(type='IndoorPointSample', num_points=20000),
dict(type='PointSample', num_points=20000),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
Expand Down
4 changes: 2 additions & 2 deletions configs/imvotenet/imvotenet_stage2_16x8_sunrgbd-3d-10class.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@
rot_range=[-0.523599, 0.523599],
scale_ratio_range=[0.85, 1.15],
shift_height=True),
dict(type='IndoorPointSample', num_points=20000),
dict(type='PointSample', num_points=20000),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(
type='Collect3D',
Expand Down Expand Up @@ -225,7 +225,7 @@
sync_2d=False,
flip_ratio_bev_horizontal=0.5,
),
dict(type='IndoorPointSample', num_points=20000),
dict(type='PointSample', num_points=20000),
dict(
type='DefaultFormatBundle3D',
class_names=class_names,
Expand Down
10 changes: 5 additions & 5 deletions docs/tutorials/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ train_pipeline = [ # Training pipeline, refer to mmdet3d.datasets.pipelines for
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33, 34,
36, 39), # all valid categories ids
max_cat_id=40), # max possible category id in input segmentation mask
dict(type='IndoorPointSample', # Sample indoor points, refer to mmdet3d.datasets.pipelines.indoor_sample for more details
dict(type='PointSample', # Sample points, refer to mmdet3d.datasets.pipelines.transforms_3d for more details
num_points=40000), # Number of points to be sampled
dict(type='IndoorFlipData', # Augmentation pipeline that flip points and 3d boxes
flip_ratio_yz=0.5, # Probability of being flipped along yz plane
Expand Down Expand Up @@ -232,7 +232,7 @@ test_pipeline = [ # Testing pipeline, refer to mmdet3d.datasets.pipelines for m
shift_height=True, # Whether to use shifted height
load_dim=6, # The dimension of the loaded points
use_dim=[0, 1, 2]), # Which dimensions of the points to be used
dict(type='IndoorPointSample', # Sample indoor points, refer to mmdet3d.datasets.pipelines.indoor_sample for more details
dict(type='PointSample', # Sample points, refer to mmdet3d.datasets.pipelines.transforms_3d for more details
num_points=40000), # Number of points to be sampled
dict(
type='DefaultFormatBundle3D', # Default format bundle to gather data in the pipeline, refer to mmdet3d.datasets.pipelines.formating for more details
Expand Down Expand Up @@ -286,7 +286,7 @@ data = dict(
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24,
28, 33, 34, 36, 39),
max_cat_id=40),
dict(type='IndoorPointSample', num_points=40000),
dict(type='PointSample', num_points=40000),
dict(
type='IndoorFlipData',
flip_ratio_yz=0.5,
Expand Down Expand Up @@ -325,7 +325,7 @@ data = dict(
shift_height=True,
load_dim=6,
use_dim=[0, 1, 2]),
dict(type='IndoorPointSample', num_points=40000),
dict(type='PointSample', num_points=40000),
dict(
type='DefaultFormatBundle3D',
class_names=('cabinet', 'bed', 'chair', 'sofa', 'table',
Expand All @@ -350,7 +350,7 @@ data = dict(
shift_height=True,
load_dim=6,
use_dim=[0, 1, 2]),
dict(type='IndoorPointSample', num_points=40000),
dict(type='PointSample', num_points=40000),
dict(
type='DefaultFormatBundle3D',
class_names=('cabinet', 'bed', 'chair', 'sofa', 'table',
Expand Down
6 changes: 6 additions & 0 deletions mmdet3d/core/points/base_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,8 @@ def __getitem__(self, item):
Nonzero elements in the vector will be selected.
4. `new_points = points[3:11, vector]`:
return a slice of points and attribute dims.
5. `new_points = points[4:12, 2]`:
return a slice of points with single attribute.
Note that the returned Points might share storage with this Points,
subject to Pytorch's indexing semantics.
Expand All @@ -303,6 +305,10 @@ def __getitem__(self, item):
item = list(item)
item[1] = list(range(start, stop, step))
item = tuple(item)
elif isinstance(item[1], int):
item = list(item)
item[1] = [item[1]]
item = tuple(item)
p = self.tensor[item[0], item[1]]

keep_dims = list(
Expand Down
20 changes: 12 additions & 8 deletions mmdet3d/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@
from .lyft_dataset import LyftDataset
from .nuscenes_dataset import NuScenesDataset
from .nuscenes_mono_dataset import NuScenesMonoDataset
# yapf: disable
from .pipelines import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, IndoorPatchPointSample,
IndoorPointSample, LoadAnnotations3D,
LoadPointsFromFile, LoadPointsFromMultiSweeps,
NormalizePointsColor, ObjectNameFilter, ObjectNoise,
ObjectRangeFilter, ObjectSample, PointShuffle,
PointsRangeFilter, RandomDropPointsColor, RandomFlip3D,
RandomJitterPoints, VoxelBasedPointSampler)
ObjectRangeFilter, ObjectSample, PointSample,
PointShuffle, PointsRangeFilter, RandomDropPointsColor,
RandomFlip3D, RandomJitterPoints,
VoxelBasedPointSampler)
# yapf: enable
from .s3dis_dataset import S3DISSegDataset
from .scannet_dataset import ScanNetDataset, ScanNetSegDataset
from .semantickitti_dataset import SemanticKITTIDataset
Expand All @@ -30,9 +33,10 @@
'ObjectNoise', 'GlobalRotScaleTrans', 'PointShuffle', 'ObjectRangeFilter',
'PointsRangeFilter', 'Collect3D', 'LoadPointsFromFile', 'S3DISSegDataset',
'NormalizePointsColor', 'IndoorPatchPointSample', 'IndoorPointSample',
'LoadAnnotations3D', 'GlobalAlignment', 'SUNRGBDDataset', 'ScanNetDataset',
'ScanNetSegDataset', 'SemanticKITTIDataset', 'Custom3DDataset',
'Custom3DSegDataset', 'LoadPointsFromMultiSweeps', 'WaymoDataset',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'get_loading_pipeline',
'RandomDropPointsColor', 'RandomJitterPoints', 'ObjectNameFilter'
'PointSample', 'LoadAnnotations3D', 'GlobalAlignment', 'SUNRGBDDataset',
'ScanNetDataset', 'ScanNetSegDataset', 'SemanticKITTIDataset',
'Custom3DDataset', 'Custom3DSegDataset', 'LoadPointsFromMultiSweeps',
'WaymoDataset', 'BackgroundPointsFilter', 'VoxelBasedPointSampler',
'get_loading_pipeline', 'RandomDropPointsColor', 'RandomJitterPoints',
'ObjectNameFilter'
]
17 changes: 9 additions & 8 deletions mmdet3d/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,20 @@
from .transforms_3d import (BackgroundPointsFilter, GlobalAlignment,
GlobalRotScaleTrans, IndoorPatchPointSample,
IndoorPointSample, ObjectNameFilter, ObjectNoise,
ObjectRangeFilter, ObjectSample, PointShuffle,
PointsRangeFilter, RandomDropPointsColor,
RandomFlip3D, RandomJitterPoints,
VoxelBasedPointSampler)
ObjectRangeFilter, ObjectSample, PointSample,
PointShuffle, PointsRangeFilter,
RandomDropPointsColor, RandomFlip3D,
RandomJitterPoints, VoxelBasedPointSampler)

__all__ = [
'ObjectSample', 'RandomFlip3D', 'ObjectNoise', 'GlobalRotScaleTrans',
'PointShuffle', 'ObjectRangeFilter', 'PointsRangeFilter', 'Collect3D',
'Compose', 'LoadMultiViewImageFromFiles', 'LoadPointsFromFile',
'DefaultFormatBundle', 'DefaultFormatBundle3D', 'DataBaseSampler',
'NormalizePointsColor', 'LoadAnnotations3D', 'IndoorPointSample',
'PointSegClassMapping', 'MultiScaleFlipAug3D', 'LoadPointsFromMultiSweeps',
'BackgroundPointsFilter', 'VoxelBasedPointSampler', 'GlobalAlignment',
'IndoorPatchPointSample', 'LoadImageFromFileMono3D', 'ObjectNameFilter',
'RandomDropPointsColor', 'RandomJitterPoints'
'PointSample', 'PointSegClassMapping', 'MultiScaleFlipAug3D',
'LoadPointsFromMultiSweeps', 'BackgroundPointsFilter',
'VoxelBasedPointSampler', 'GlobalAlignment', 'IndoorPatchPointSample',
'LoadImageFromFileMono3D', 'ObjectNameFilter', 'RandomDropPointsColor',
'RandomJitterPoints'
]
88 changes: 66 additions & 22 deletions mmdet3d/datasets/pipelines/transforms_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,45 +838,61 @@ def __repr__(self):


@PIPELINES.register_module()
class IndoorPointSample(object):
"""Indoor point sample.
class PointSample(object):
"""Point sample.
Sampling data to a certain number.
Args:
name (str): Name of the dataset.
num_points (int): Number of points to be sampled.
sample_range (float, optional): The range where to sample points.
"""

def __init__(self, num_points):
def __init__(self, num_points, sample_range=None, replace=False):
self.num_points = num_points

def points_random_sampling(self,
points,
num_samples,
replace=None,
return_choices=False):
self.sample_range = sample_range
self.replace = replace

def _points_random_sampling(self,
points,
num_samples,
sample_range=None,
replace=False,
return_choices=False):
"""Points random sampling.
Sample points to a certain number.
Args:
points (np.ndarray | :obj:`BasePoints`): 3D Points.
num_samples (int): Number of samples to be sampled.
replace (bool): Whether the sample is with or without replacement.
Defaults to None.
return_choices (bool): Whether return choice. Defaults to False.
sample_range (float, optional): Indicating the range where the
points will be sampled.
Defaults to None.
replace (bool, optional): Sampling with or without replacement.
Defaults to None.
return_choices (bool, optional): Whether return choice.
Defaults to False.
Returns:
tuple[np.ndarray] | np.ndarray:
- points (np.ndarray | :obj:`BasePoints`): 3D Points.
- choices (np.ndarray, optional): The generated random samples.
"""
if replace is None:
if not replace:
replace = (points.shape[0] < num_samples)
choices = np.random.choice(
points.shape[0], num_samples, replace=replace)
point_range = range(len(points))
if sample_range is not None and not replace:
# Only sampling the near points when len(points) >= num_samples
depth = np.linalg.norm(points.tensor, axis=1)
far_inds = np.where(depth > sample_range)[0]
near_inds = np.where(depth <= sample_range)[0]
point_range = near_inds
num_samples -= len(far_inds)
choices = np.random.choice(point_range, num_samples, replace=replace)
if sample_range is not None and not replace:
choices = np.concatenate((far_inds, choices))
# Shuffle points after sampling
np.random.shuffle(choices)
if return_choices:
return points[choices], choices
else:
Expand All @@ -887,14 +903,23 @@ def __call__(self, results):
Args:
input_dict (dict): Result dict from loading pipeline.
Returns:
dict: Results after sampling, 'points', 'pts_instance_mask' \
and 'pts_semantic_mask' keys are updated in the result dict.
"""
from mmdet3d.core.points import CameraPoints
points = results['points']
points, choices = self.points_random_sampling(
points, self.num_points, return_choices=True)
# Points in Camera coord can provide the depth information.
# TODO: Need to suport distance-based sampling for other coord system.
if self.sample_range is not None:
assert isinstance(points, CameraPoints), \
'Sampling based on distance is only appliable for CAMERA coord'
points, choices = self._points_random_sampling(
points,
self.num_points,
self.sample_range,
self.replace,
return_choices=True)
results['points'] = points

pts_instance_mask = results.get('pts_instance_mask', None)
Expand All @@ -913,10 +938,29 @@ def __call__(self, results):
def __repr__(self):
"""str: Return a string that describes the module."""
repr_str = self.__class__.__name__
repr_str += f'(num_points={self.num_points})'
repr_str += f'(num_points={self.num_points},'
repr_str += f' sample_range={self.sample_range})'

return repr_str


@PIPELINES.register_module()
class IndoorPointSample(PointSample):
"""Indoor point sample.
Sampling data to a certain number.
NOTE: IndoorPointSample is deprecated in favor of PointSample
Args:
num_points (int): Number of points to be sampled.
"""

def __init__(self, *args, **kwargs):
warnings.warn(
'IndoorPointSample is deprecated in favor of PointSample')
super(IndoorPointSample, self).__init__(*args, **kwargs)


@PIPELINES.register_module()
class IndoorPatchPointSample(object):
r"""Indoor point sample within a patch. Modified from `PointNet++ <https://
Expand Down
2 changes: 1 addition & 1 deletion tests/test_data/test_datasets/test_scannet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_getitem():
type='PointSegClassMapping',
valid_cat_ids=(3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 14, 16, 24, 28, 33,
34, 36, 39)),
dict(type='IndoorPointSample', num_points=5),
dict(type='PointSample', num_points=5),
dict(
type='RandomFlip3D',
sync_2d=False,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_data/test_datasets/test_sunrgbd_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _generate_sunrgbd_dataset_config():
rot_range=[-0.523599, 0.523599],
scale_ratio_range=[0.85, 1.15],
shift_height=True),
dict(type='IndoorPointSample', num_points=5),
dict(type='PointSample', num_points=5),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(
type='Collect3D',
Expand Down Expand Up @@ -73,7 +73,7 @@ def _generate_sunrgbd_multi_modality_dataset_config():
rot_range=[-0.523599, 0.523599],
scale_ratio_range=[0.85, 1.15],
shift_height=True),
dict(type='IndoorPointSample', num_points=5),
dict(type='PointSample', num_points=5),
dict(type='DefaultFormatBundle3D', class_names=class_names),
dict(
type='Collect3D',
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def test_multi_scale_flip_aug_3D():
'sync_2d': False,
'flip_ratio_bev_horizontal': 0.5
}, {
'type': 'IndoorPointSample',
'type': 'PointSample',
'num_points': 5
}, {
'type':
Expand Down
Loading

0 comments on commit fc9e0d9

Please sign in to comment.