Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bug]Fix spconv2 model load bug #1699

Merged
merged 7 commits into from
Aug 5, 2022
Merged
29 changes: 6 additions & 23 deletions mmdet3d/ops/spconv/overwrite_spconv/write_spconv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,11 @@ def register_spconv2():
CONV_LAYERS._register_module(SubMConv2d, 'SubMConv2d', force=True)
CONV_LAYERS._register_module(SubMConv3d, 'SubMConv3d', force=True)
CONV_LAYERS._register_module(SubMConv4d, 'SubMConv4d', force=True)
SparseModule._version = 2
SparseModule._load_from_state_dict = _load_from_state_dict
SparseModule._save_to_state_dict = _save_to_state_dict
return True


def _save_to_state_dict(self, destination, prefix, keep_vars):
"""Rewrite this func to compat the convolutional kernel weights between
spconv 1.x in MMCV and 2.x in spconv2.x.

Kernel weights in MMCV spconv has shape in (D,H,W,in_channel,out_channel) ,
while those in spcon2.x is in (out_channel,D,H,W,in_channel).
"""
for name, param in self._parameters.items():
if param is not None:
param = param if keep_vars else param.detach()
if name == 'weight':
dims = list(range(1, len(param.shape))) + [0]
param = param.permute(*dims)
destination[prefix + name] = param
for name, buf in self._buffers.items():
if buf is not None and name not in self._non_persistent_buffers_set:
destination[prefix + name] = buf if keep_vars else buf.detach()


def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs):
"""Rewrite this func to compat the convolutional kernel weights between
Expand All @@ -66,6 +47,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
Kernel weights in MMCV spconv has shape in (D,H,W,in_channel,out_channel) ,
while those in spcon2.x is in (out_channel,D,H,W,in_channel).
"""
version = local_metadata.get('version', None)
for hook in self._load_state_dict_pre_hooks.values():
hook(state_dict, prefix, local_metadata, strict, missing_keys,
unexpected_keys, error_msgs)
Expand All @@ -83,9 +65,10 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
# 0.3.* to version 0.4+
if len(param.shape) == 0 and len(input_param.shape) == 1:
input_param = input_param[0]
dims = [len(input_param.shape) - 1] + list(
range(len(input_param.shape) - 1))
input_param = input_param.permute(*dims)
if version != 2:
dims = [len(input_param.shape) - 1] + list(
range(len(input_param.shape) - 1))
input_param = input_param.permute(*dims)
if input_param.shape != param.shape:
# local shape should match the one in checkpoint
error_msgs.append(
Expand Down