Commit 3fb5aa3: Merge 5577f50 into aa10365
Authored by Xiangxu-0103 on Apr 22, 2023 (2 parents: aa10365 + 5577f50)
Showing 17 changed files with 475 additions and 114 deletions.
48 changes: 29 additions & 19 deletions configs/_base_/datasets/s3dis-seg.py
@@ -73,25 +73,6 @@
with_seg_3d=True,
backend_args=backend_args),
dict(type='NormalizePointsColor', color_mean=None),
dict(
# a wrapper in order to successfully call test function
# actually we don't perform test-time-aug
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.0,
flip_ratio_bev_vertical=0.0),
]),
dict(type='Pack3DDetInputs', keys=['points'])
]
# construct a pipeline for data and gt loading in show function
@@ -109,6 +90,33 @@
dict(type='NormalizePointsColor', color_mean=None),
dict(type='Pack3DDetInputs', keys=['points'])
]
tta_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5],
backend_args=backend_args),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True,
backend_args=backend_args),
dict(type='NormalizePointsColor', color_mean=None),
dict(
type='TestTimeAug',
transforms=[[
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.,
flip_ratio_bev_vertical=0.)
], [dict(type='Pack3DDetInputs', keys=['points'])]])
]

# train on area 1, 2, 3, 4, 6
# test on area 5
@@ -157,3 +165,5 @@
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')

tta_model = dict(type='Seg3DTTAModel')
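
Note (not part of the diff): the tta_pipeline/tta_model pair above follows the MMEngine test-time-augmentation convention. At test time the runner is expected to swap the normal test pipeline for tta_pipeline and wrap the trained segmentor with the TTA model, roughly as in the hedged sketch below; the exact wiring (for example a --tta switch in tools/test.py) is not shown in this commit, so treat the field names as assumptions. The scannet-seg.py diff below receives the identical change.

# Hedged sketch of how the new config fields are typically consumed at test time
# (cfg is assumed to be a full config that inherits this dataset file).
from mmengine.config import Config

cfg = Config.fromfile('some_segmentor_config.py')        # hypothetical config path
cfg.test_dataloader.dataset.pipeline = cfg.tta_pipeline  # switch to the TTA pipeline
cfg.model = dict(**cfg.tta_model, module=cfg.model)      # wrap the segmentor with Seg3DTTAModel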
48 changes: 29 additions & 19 deletions configs/_base_/datasets/scannet-seg.py
@@ -73,25 +73,6 @@
with_seg_3d=True,
backend_args=backend_args),
dict(type='NormalizePointsColor', color_mean=None),
dict(
# a wrapper in order to successfully call test function
# actually we don't perform test-time-aug
type='MultiScaleFlipAug3D',
img_scale=(1333, 800),
pts_scale_ratio=1,
flip=False,
transforms=[
dict(
type='GlobalRotScaleTrans',
rot_range=[0, 0],
scale_ratio_range=[1., 1.],
translation_std=[0, 0, 0]),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.0,
flip_ratio_bev_vertical=0.0),
]),
dict(type='Pack3DDetInputs', keys=['points'])
]
# construct a pipeline for data and gt loading in show function
@@ -109,6 +90,33 @@
dict(type='NormalizePointsColor', color_mean=None),
dict(type='Pack3DDetInputs', keys=['points'])
]
tta_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='DEPTH',
shift_height=False,
use_color=True,
load_dim=6,
use_dim=[0, 1, 2, 3, 4, 5],
backend_args=backend_args),
dict(
type='LoadAnnotations3D',
with_bbox_3d=False,
with_label_3d=False,
with_mask_3d=False,
with_seg_3d=True,
backend_args=backend_args),
dict(type='NormalizePointsColor', color_mean=None),
dict(
type='TestTimeAug',
transforms=[[
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.,
flip_ratio_bev_vertical=0.)
], [dict(type='Pack3DDetInputs', keys=['points'])]])
]

train_dataloader = dict(
batch_size=8,
@@ -152,3 +160,5 @@
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')

tta_model = dict(type='Seg3DTTAModel')
100 changes: 70 additions & 30 deletions configs/_base_/datasets/semantickitti.py
@@ -82,7 +82,7 @@
seg_offset=2**16,
dataset_type='semantickitti',
backend_args=backend_args),
dict(type='PointSegClassMapping', ),
dict(type='PointSegClassMapping'),
dict(
type='RandomFlip3D',
sync_2d=False,
@@ -112,12 +112,21 @@
seg_offset=2**16,
dataset_type='semantickitti',
backend_args=backend_args),
dict(type='PointSegClassMapping', ),
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
dict(type='PointSegClassMapping'),
dict(type='Pack3DDetInputs', keys=['points'])
]
# construct a pipeline for data and gt loading in show function
# please keep its loading function consistent with test_pipeline (e.g. client)
eval_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
load_dim=4,
use_dim=4,
backend_args=backend_args),
dict(type='Pack3DDetInputs', keys=['points'])
]
tta_pipeline = [
dict(
type='LoadPointsFromFile',
coord_type='LIDAR',
@@ -133,46 +142,75 @@
seg_offset=2**16,
dataset_type='semantickitti',
backend_args=backend_args),
dict(type='PointSegClassMapping', ),
dict(type='Pack3DDetInputs', keys=['points', 'pts_semantic_mask'])
dict(type='PointSegClassMapping'),
dict(
type='TestTimeAug',
transforms=[[
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.,
flip_ratio_bev_vertical=0.),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=0.,
flip_ratio_bev_vertical=1.),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=1.,
flip_ratio_bev_vertical=0.),
dict(
type='RandomFlip3D',
sync_2d=False,
flip_ratio_bev_horizontal=1.,
flip_ratio_bev_vertical=1.)
],
[
dict(
type='GlobalRotScaleTrans',
rot_range=[pcd_rotate_range, pcd_rotate_range],
scale_ratio_range=[
pcd_scale_factor, pcd_scale_factor
],
translation_std=[0, 0, 0])
for pcd_rotate_range in [-0.78539816, 0.0, 0.78539816]
for pcd_scale_factor in [0.95, 1.0, 1.05]
], [dict(type='Pack3DDetInputs', keys=['points'])]])
]
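
Note (not part of the diff): mmcv's TestTimeAug enumerates the cartesian product of its transform groups, so the SemanticKITTI tta_pipeline above yields 4 flip variants x 9 rotation/scale combinations x 1 packing step = 36 augmented views per scan. A quick, self-contained sanity check of that count:

import itertools

flips = [(h, v) for h in (0.0, 1.0) for v in (0.0, 1.0)]           # 4 deterministic flip settings
rot_scale = [(r, s) for r in (-0.78539816, 0.0, 0.78539816)
             for s in (0.95, 1.0, 1.05)]                            # 9 rotation/scale settings
print(len(list(itertools.product(flips, rot_scale))))               # 36 views per scan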

train_dataloader = dict(
batch_size=2,
num_workers=4,
persistent_workers=True,
sampler=dict(type='DefaultSampler', shuffle=True),
dataset=dict(
type='RepeatDataset',
times=1,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='semantickitti_infos_train.pkl',
pipeline=train_pipeline,
metainfo=metainfo,
modality=input_modality,
ignore_index=19,
backend_args=backend_args)),
)
type=dataset_type,
data_root=data_root,
ann_file='semantickitti_infos_train.pkl',
pipeline=train_pipeline,
metainfo=metainfo,
modality=input_modality,
ignore_index=19,
backend_args=backend_args))

test_dataloader = dict(
batch_size=1,
num_workers=1,
persistent_workers=True,
drop_last=False,
sampler=dict(type='DefaultSampler', shuffle=False),
dataset=dict(
type='RepeatDataset',
times=1,
dataset=dict(
type=dataset_type,
data_root=data_root,
ann_file='semantickitti_infos_val.pkl',
pipeline=test_pipeline,
metainfo=metainfo,
modality=input_modality,
ignore_index=19,
test_mode=True,
backend_args=backend_args)),
)
type=dataset_type,
data_root=data_root,
ann_file='semantickitti_infos_val.pkl',
pipeline=test_pipeline,
metainfo=metainfo,
modality=input_modality,
ignore_index=19,
test_mode=True,
backend_args=backend_args))

val_dataloader = test_dataloader

@@ -182,3 +220,5 @@
vis_backends = [dict(type='LocalVisBackend')]
visualizer = dict(
type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer')

tta_model = dict(type='Seg3DTTAModel')
2 changes: 1 addition & 1 deletion configs/minkunet/minkunet_w32_8xb2-15e_semantickitti.py
@@ -24,7 +24,7 @@
]

train_dataloader = dict(
sampler=dict(seed=0), dataset=dict(dataset=dict(pipeline=train_pipeline)))
sampler=dict(seed=0), dataset=dict(pipeline=train_pipeline))

lr = 0.24
optim_wrapper = dict(
2 changes: 1 addition & 1 deletion configs/spvcnn/spvcnn_w32_8xb2-15e_semantickitti.py
@@ -24,7 +24,7 @@
]

train_dataloader = dict(
sampler=dict(seed=0), dataset=dict(dataset=dict(pipeline=train_pipeline)))
sampler=dict(seed=0), dataset=dict(pipeline=train_pipeline))

lr = 0.24
optim_wrapper = dict(
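
Note (not part of the diff): these two one-line edits in the MinkUNet and SPVCNN configs follow from the semantickitti.py change above, where the base dataset is no longer wrapped in RepeatDataset, so the override loses one dataset=dict(...) nesting level. A rough illustration of the merge semantics involved (placeholder names, not repo code):

base = dict(dataset=dict(type='SomeDataset', pipeline=['old']))   # stand-in for the base config
override = dict(dataset=dict(pipeline=['new']))                   # stand-in for the derived config
# MMEngine merges dicts recursively: 'type' is kept and only 'pipeline' is
# replaced, giving dict(dataset=dict(type='SomeDataset', pipeline=['new'])).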
8 changes: 4 additions & 4 deletions mmdet3d/datasets/transforms/__init__.py
@@ -1,10 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dbsampler import DataBaseSampler
from .formating import Pack3DDetInputs
from .loading import (LoadAnnotations3D, LoadImageFromFileMono3D,
LoadMultiViewImageFromFiles, LoadPointsFromDict,
LoadPointsFromFile, LoadPointsFromMultiSweeps,
MonoDet3DInferencerLoader,
from .loading import (LidarDet3DInferencerLoader, LoadAnnotations3D,
LoadImageFromFileMono3D, LoadMultiViewImageFromFiles,
LoadPointsFromDict, LoadPointsFromFile,
LoadPointsFromMultiSweeps, MonoDet3DInferencerLoader,
MultiModalityDet3DInferencerLoader, NormalizePointsColor,
PointSegClassMapping)
from .test_time_aug import MultiScaleFlipAug3D
6 changes: 5 additions & 1 deletion mmdet3d/models/segmentors/__init__.py
@@ -3,5 +3,9 @@
from .cylinder3d import Cylinder3D
from .encoder_decoder import EncoderDecoder3D
from .minkunet import MinkUNet
from .seg3d_tta import Seg3DTTAModel

__all__ = ['Base3DSegmentor', 'EncoderDecoder3D', 'Cylinder3D', 'MinkUNet']
__all__ = [
'Base3DSegmentor', 'EncoderDecoder3D', 'Cylinder3D', 'MinkUNet',
'Seg3DTTAModel'
]
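
Note (not part of the diff): the new seg3d_tta.py module itself is not rendered on this page. As orientation only, here is a minimal sketch of what such a wrapper typically looks like, assuming it subclasses mmengine's BaseTTAModel and merges the per-view logits that postprocess_result (see base.py below) now stores on each sample; the class name is deliberately different from the real one.

from typing import List

import torch
from mmengine.model import BaseTTAModel

from mmdet3d.registry import MODELS
from mmdet3d.structures import Det3DDataSample


@MODELS.register_module()
class Seg3DTTASketch(BaseTTAModel):
    """Illustrative TTA wrapper: average softmaxed logits over augmented views."""

    def merge_preds(
            self, data_samples_list: List[List[Det3DDataSample]]
    ) -> List[Det3DDataSample]:
        predictions = []
        for data_samples in data_samples_list:
            # Stack the (num_classes, num_points) logits of every view of one
            # scan, average their softmax, and take the per-point argmax.
            logits = torch.stack([
                ds.pts_seg_logits.pts_seg_logits.softmax(dim=0)
                for ds in data_samples
            ]).mean(dim=0)
            data_samples[0].pred_pts_seg.pts_semantic_mask = logits.argmax(dim=0)
            predictions.append(data_samples[0])
        return predictions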
26 changes: 14 additions & 12 deletions mmdet3d/models/segmentors/base.py
@@ -132,17 +132,12 @@ def _forward(self,
"""
pass

@abstractmethod
def aug_test(self, batch_inputs, batch_data_samples):
"""Placeholder for augmentation test."""
pass

def postprocess_result(self, seg_pred_list: List[dict],
def postprocess_result(self, seg_logits_list: List[Tensor],
batch_data_samples: SampleList) -> SampleList:
"""Convert results list to `Det3DDataSample`.
Args:
seg_logits_list (List[dict]): List of segmentation results,
seg_logits_list (List[Tensor]): List of segmentation results,
seg_logits from model of each input point clouds sample.
batch_data_samples (List[:obj:`Det3DDataSample`]): The det3d data
samples. It usually includes information such as `metainfo` and
@@ -152,12 +147,19 @@ def postprocess_result(self, seg_pred_list: List[dict],
List[:obj:`Det3DDataSample`]: Segmentation results of the input
points. Each Det3DDataSample usually contains:
- ``pred_pts_seg`` (PixelData): Prediction of 3D semantic
- ``pred_pts_seg`` (PointData): Prediction of 3D semantic
segmentation.
- ``pts_seg_logits`` (PointData): Predicted logits of 3D semantic
segmentation before normalization.
"""

for i in range(len(seg_pred_list)):
seg_pred = seg_pred_list[i]
batch_data_samples[i].set_data(
{'pred_pts_seg': PointData(**{'pts_semantic_mask': seg_pred})})
for i in range(len(seg_logits_list)):
seg_logits = seg_logits_list[i]
seg_pred = seg_logits.argmax(dim=0)
batch_data_samples[i].set_data({
'pts_seg_logits':
PointData(**{'pts_seg_logits': seg_logits}),
'pred_pts_seg':
PointData(**{'pts_semantic_mask': seg_pred})
})
return batch_data_samples
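
Note (not part of the diff): with this change postprocess_result receives raw logits rather than precomputed labels. Each entry of seg_logits_list is a (num_classes, num_points) tensor, so argmax(dim=0) yields one class index per point, and both the logits (needed for TTA merging) and the mask are attached to the sample. A tiny shape illustration with made-up values:

import torch

num_classes, num_points = 13, 5
seg_logits = torch.randn(num_classes, num_points)   # per-point class scores
seg_pred = seg_logits.argmax(dim=0)                 # shape (num_points,), one label per point
assert seg_pred.shape == (num_points,)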
14 changes: 8 additions & 6 deletions mmdet3d/models/segmentors/cylinder3d.py
@@ -127,16 +127,18 @@ def predict(self,
List[:obj:`Det3DDataSample`]: Segmentation results of the input
points. Each Det3DDataSample usually contains:
- ``pred_pts_seg`` (PixelData): Prediction of 3D semantic
- ``pred_pts_seg`` (PointData): Prediction of 3D semantic
segmentation.
- ``pts_seg_logits`` (PointData): Predicted logits of 3D semantic
segmentation before normalization.
"""
# 3D segmentation requires per-point prediction, so it's impossible
# to use down-sampling to get a batch of scenes with same num_points
# therefore, we only support testing one scene every time
x = self.extract_feat(batch_inputs_dict)
seg_pred_list = self.decode_head.predict(x, batch_inputs_dict,
batch_data_samples)
for i in range(len(seg_pred_list)):
seg_pred_list[i] = seg_pred_list[i].argmax(1).cpu()
seg_logits_list = self.decode_head.predict(x, batch_inputs_dict,
batch_data_samples)
for i in range(len(seg_logits_list)):
seg_logits_list[i] = seg_logits_list[i].transpose(0, 1)

return self.postprocess_result(seg_pred_list, batch_data_samples)
return self.postprocess_result(seg_logits_list, batch_data_samples)
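
# Shape note (not part of the diff): the decode head returns per-scan logits of
# shape (num_points, num_classes), as the removed argmax(1) call implies, so they
# are transposed to (num_classes, num_points) before the shared postprocess_result
# in base.py takes argmax(dim=0) and keeps the raw logits for TTA merging.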

(The remaining changed files in this commit are not rendered in this view.)