From 52a2503f341b93ef0a87887d38b2986b3549381b Mon Sep 17 00:00:00 2001 From: VVsssssk Date: Tue, 7 Jun 2022 13:40:11 +0800 Subject: [PATCH 1/2] fix parta2 bug --- ...v_PartA2_secfpn_2x8_cyclic_80e_kitti-3d-3class.py | 3 +++ mmdet3d/models/detectors/two_stage.py | 3 ++- mmdet3d/models/middle_encoders/sparse_unet.py | 9 +++++---- .../models/roi_heads/bbox_heads/parta2_bbox_head.py | 12 ++++++++++-- 4 files changed, 20 insertions(+), 7 deletions(-) diff --git a/configs/parta2/hv_PartA2_secfpn_2x8_cyclic_80e_kitti-3d-3class.py b/configs/parta2/hv_PartA2_secfpn_2x8_cyclic_80e_kitti-3d-3class.py index 4bb3b2c94c..1166231890 100644 --- a/configs/parta2/hv_PartA2_secfpn_2x8_cyclic_80e_kitti-3d-3class.py +++ b/configs/parta2/hv_PartA2_secfpn_2x8_cyclic_80e_kitti-3d-3class.py @@ -90,6 +90,7 @@ pipeline=train_pipeline, modality=input_modality, classes=class_names, + box_type_3d='LiDAR', test_mode=False)), val=dict( type=dataset_type, @@ -100,6 +101,7 @@ pipeline=test_pipeline, modality=input_modality, classes=class_names, + box_type_3d='LiDAR', test_mode=True), test=dict( type=dataset_type, @@ -110,6 +112,7 @@ pipeline=test_pipeline, modality=input_modality, classes=class_names, + box_type_3d='LiDAR', test_mode=True)) # Part-A2 uses a different learning rate from what SECOND uses. diff --git a/mmdet3d/models/detectors/two_stage.py b/mmdet3d/models/detectors/two_stage.py index 06a036b0c8..707f706d55 100644 --- a/mmdet3d/models/detectors/two_stage.py +++ b/mmdet3d/models/detectors/two_stage.py @@ -30,7 +30,8 @@ def __init__(self, 'please use "init_cfg" instead') backbone.pretrained = pretrained self.backbone = build_backbone(backbone) - + self.train_cfg = train_cfg + self.test_cfg = test_cfg if neck is not None: self.neck = build_neck(neck) diff --git a/mmdet3d/models/middle_encoders/sparse_unet.py b/mmdet3d/models/middle_encoders/sparse_unet.py index c8af2ed0b6..005e34ebeb 100644 --- a/mmdet3d/models/middle_encoders/sparse_unet.py +++ b/mmdet3d/models/middle_encoders/sparse_unet.py @@ -11,6 +11,7 @@ from mmcv.runner import BaseModule, auto_fp16 from mmdet3d.ops import SparseBasicBlock, make_sparse_convmodule +from mmdet3d.ops.sparse_block import replace_feature from ..builder import MIDDLE_ENCODERS @@ -168,10 +169,11 @@ def decoder_layer_forward(self, x_lateral, x_bottom, lateral_layer, :obj:`SparseConvTensor`: Upsampled feature. """ x = lateral_layer(x_lateral) - x.features = torch.cat((x_bottom.features, x.features), dim=1) + x = replace_feature(x, torch.cat((x_bottom.features, x.features), + dim=1)) x_merge = merge_layer(x) x = self.reduce_channel(x, x_merge.features.shape[1]) - x.features = x_merge.features + x.features + x = replace_feature(x, x_merge.features + x.features) x = upsample_layer(x) return x @@ -191,8 +193,7 @@ def reduce_channel(x, out_channels): n, in_channels = features.shape assert (in_channels % out_channels == 0) and (in_channels >= out_channels) - - x.features = features.view(n, out_channels, -1).sum(dim=2) + x = replace_feature(x, features.view(n, out_channels, -1).sum(dim=2)) return x def make_encoder_layers(self, make_block, norm_cfg, in_channels): diff --git a/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py b/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py index c569c4e353..6f5ea722b9 100644 --- a/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py +++ b/mmdet3d/models/roi_heads/bbox_heads/parta2_bbox_head.py @@ -2,7 +2,15 @@ import numpy as np import torch from mmcv.cnn import ConvModule, normal_init -from mmcv.ops import SparseConvTensor, SparseMaxPool3d, SparseSequential + +from mmdet3d.ops.spconv import IS_SPCONV2_AVAILABLE + +if IS_SPCONV2_AVAILABLE: + from spconv.pytorch import (SparseConvTensor, SparseMaxPool3d, + SparseSequential) +else: + from mmcv.ops import SparseConvTensor, SparseMaxPool3d, SparseSequential + from mmcv.runner import BaseModule from torch import nn as nn @@ -252,7 +260,7 @@ def forward(self, seg_feats, part_feats): sparse_idx[:, 2], sparse_idx[:, 3]] seg_features = seg_feats[sparse_idx[:, 0], sparse_idx[:, 1], sparse_idx[:, 2], sparse_idx[:, 3]] - coords = sparse_idx.int() + coords = sparse_idx.int().contiguous() part_features = SparseConvTensor(part_features, coords, sparse_shape, rcnn_batch_size) seg_features = SparseConvTensor(seg_features, coords, sparse_shape, From c18e77ba9262edd88082dbf68a84c1b6bf61cf18 Mon Sep 17 00:00:00 2001 From: VVsssssk Date: Fri, 5 Aug 2022 11:41:02 +0800 Subject: [PATCH 2/2] fix spconv2 bug --- .../spconv/overwrite_spconv/write_spconv2.py | 29 ++++--------------- 1 file changed, 6 insertions(+), 23 deletions(-) diff --git a/mmdet3d/ops/spconv/overwrite_spconv/write_spconv2.py b/mmdet3d/ops/spconv/overwrite_spconv/write_spconv2.py index 237051ebcc..362ff7ae2f 100644 --- a/mmdet3d/ops/spconv/overwrite_spconv/write_spconv2.py +++ b/mmdet3d/ops/spconv/overwrite_spconv/write_spconv2.py @@ -34,30 +34,11 @@ def register_spconv2(): CONV_LAYERS._register_module(SubMConv2d, 'SubMConv2d', force=True) CONV_LAYERS._register_module(SubMConv3d, 'SubMConv3d', force=True) CONV_LAYERS._register_module(SubMConv4d, 'SubMConv4d', force=True) + SparseModule._version = 2 SparseModule._load_from_state_dict = _load_from_state_dict - SparseModule._save_to_state_dict = _save_to_state_dict return True -def _save_to_state_dict(self, destination, prefix, keep_vars): - """Rewrite this func to compat the convolutional kernel weights between - spconv 1.x in MMCV and 2.x in spconv2.x. - - Kernel weights in MMCV spconv has shape in (D,H,W,in_channel,out_channel) , - while those in spcon2.x is in (out_channel,D,H,W,in_channel). - """ - for name, param in self._parameters.items(): - if param is not None: - param = param if keep_vars else param.detach() - if name == 'weight': - dims = list(range(1, len(param.shape))) + [0] - param = param.permute(*dims) - destination[prefix + name] = param - for name, buf in self._buffers.items(): - if buf is not None and name not in self._non_persistent_buffers_set: - destination[prefix + name] = buf if keep_vars else buf.detach() - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): """Rewrite this func to compat the convolutional kernel weights between @@ -66,6 +47,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, Kernel weights in MMCV spconv has shape in (D,H,W,in_channel,out_channel) , while those in spcon2.x is in (out_channel,D,H,W,in_channel). """ + version = local_metadata.get('version', None) for hook in self._load_state_dict_pre_hooks.values(): hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) @@ -83,9 +65,10 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, # 0.3.* to version 0.4+ if len(param.shape) == 0 and len(input_param.shape) == 1: input_param = input_param[0] - dims = [len(input_param.shape) - 1] + list( - range(len(input_param.shape) - 1)) - input_param = input_param.permute(*dims) + if version != 2: + dims = [len(input_param.shape) - 1] + list( + range(len(input_param.shape) - 1)) + input_param = input_param.permute(*dims) if input_param.shape != param.shape: # local shape should match the one in checkpoint error_msgs.append(