From 607c4c64410eef6729e2c4efdaae15c4187337e8 Mon Sep 17 00:00:00 2001 From: ChaimZhu Date: Wed, 20 Oct 2021 15:52:15 +0800 Subject: [PATCH 1/4] add new converter --- .../convert_h3dnet_checkpoints.py | 175 ++++++++++++++++++ 1 file changed, 175 insertions(+) create mode 100644 tools/model_converters/convert_h3dnet_checkpoints.py diff --git a/tools/model_converters/convert_h3dnet_checkpoints.py b/tools/model_converters/convert_h3dnet_checkpoints.py new file mode 100644 index 0000000000..513b4b21ef --- /dev/null +++ b/tools/model_converters/convert_h3dnet_checkpoints.py @@ -0,0 +1,175 @@ +import argparse +import tempfile +import torch +from mmcv import Config +from mmcv.runner import load_state_dict + +from mmdet3d.models import build_detector + + +def parse_args(): + parser = argparse.ArgumentParser( + description='MMDet3D upgrade model version(before v0.6.0) of H3DNet') + parser.add_argument('checkpoint', help='checkpoint file') + parser.add_argument('--out', help='path of the output checkpoint file') + args = parser.parse_args() + return args + + +def parse_config(config_strings): + """Parse config from strings. + + Args: + config_strings (string): strings of model config. + + Returns: + Config: model config + """ + temp_file = tempfile.NamedTemporaryFile() + config_path = f'{temp_file.name}.py' + with open(config_path, 'w') as f: + f.write(config_strings) + + config = Config.fromfile(config_path) + + # Update backbone config + if 'pool_mod' in config.model.backbone.backbones: + config.model.backbone.backbones.pop('pool_mod') + + if 'sa_cfg' not in config.model.backbone: + config.model.backbone['sa_cfg'] = dict( + type='PointSAModule', + pool_mod='max', + use_xyz=True, + normalize_xyz=True) + + if 'type' not in config.model.rpn_head.vote_aggregation_cfg: + config.model.rpn_head.vote_aggregation_cfg['type'] = 'PointSAModule' + + # Update rpn_head config + if 'pred_layer_cfg' not in config.model.rpn_head: + config.model.rpn_head['pred_layer_cfg'] = dict( + in_channels=128, shared_conv_channels=(128, 128), bias=True) + + if 'feat_channels' in config.model.rpn_head: + config.model.rpn_head.pop('feat_channels') + + if 'vote_moudule_cfg' in config.model.rpn_head: + config.model.rpn_head['vote_module_cfg'] = config.model.rpn_head.pop( + 'vote_moudule_cfg') + + if config.model.rpn_head.vote_aggregation_cfg.use_xyz: + config.model.rpn_head.vote_aggregation_cfg.mlp_channels[0] -= 3 + + for cfg in config.model.roi_head.primitive_list: + cfg['vote_module_cfg'] = cfg.pop('vote_moudule_cfg') + cfg.vote_aggregation_cfg.mlp_channels[0] -= 3 + if 'type' not in cfg.vote_aggregation_cfg: + cfg.vote_aggregation_cfg['type'] = 'PointSAModule' + + if 'type' not in config.model.roi_head.bbox_head.suface_matching_cfg: + config.model.roi_head.bbox_head.suface_matching_cfg[ + 'type'] = 'PointSAModule' + + if config.model.roi_head.bbox_head.suface_matching_cfg.use_xyz: + config.model.roi_head.bbox_head.suface_matching_cfg.mlp_channels[ + 0] -= 3 + + if 'type' not in config.model.roi_head.bbox_head.line_matching_cfg: + config.model.roi_head.bbox_head.line_matching_cfg[ + 'type'] = 'PointSAModule' + + if config.model.roi_head.bbox_head.line_matching_cfg.use_xyz: + config.model.roi_head.bbox_head.line_matching_cfg.mlp_channels[0] -= 3 + + if 'proposal_module_cfg' in config.model.roi_head.bbox_head: + config.model.roi_head.bbox_head.pop('proposal_module_cfg') + + temp_file.close() + + return config + + +def main(): + """Convert keys in checkpoints for VoteNet. + + There can be some breaking changes during the development of mmdetection3d, + and this tool is used for upgrading checkpoints trained with old versions + (before v0.6.0) to the latest one. + """ + args = parse_args() + checkpoint = torch.load(args.checkpoint) + cfg = parse_config(checkpoint['meta']['config']) + # Build the model and load checkpoint + model = build_detector( + cfg.model, + train_cfg=cfg.get('train_cfg'), + test_cfg=cfg.get('test_cfg')) + orig_ckpt = checkpoint['state_dict'] + converted_ckpt = orig_ckpt.copy() + + if cfg['dataset_type'] == 'ScanNetDataset': + NUM_CLASSES = 18 + elif cfg['dataset_type'] == 'SUNRGBDDataset': + NUM_CLASSES = 10 + else: + raise NotImplementedError + + RENAME_PREFIX = { + 'rpn_head.conv_pred.0': 'rpn_head.conv_pred.shared_convs.layer0', + 'rpn_head.conv_pred.1': 'rpn_head.conv_pred.shared_convs.layer1' + } + + DEL_KEYS = [ + 'rpn_head.conv_pred.0.bn.num_batches_tracked', + 'rpn_head.conv_pred.1.bn.num_batches_tracked' + ] + + EXTRACT_KEYS = { + 'rpn_head.conv_pred.conv_cls.weight': + ('rpn_head.conv_pred.conv_out.weight', [(0, 2), (-NUM_CLASSES, -1)]), + 'rpn_head.conv_pred.conv_cls.bias': + ('rpn_head.conv_pred.conv_out.bias', [(0, 2), (-NUM_CLASSES, -1)]), + 'rpn_head.conv_pred.conv_reg.weight': + ('rpn_head.conv_pred.conv_out.weight', [(2, -NUM_CLASSES)]), + 'rpn_head.conv_pred.conv_reg.bias': + ('rpn_head.conv_pred.conv_out.bias', [(2, -NUM_CLASSES)]) + } + + # Delete some useless keys + for key in DEL_KEYS: + converted_ckpt.pop(key) + + # Rename keys with specific prefix + RENAME_KEYS = dict() + for old_key in converted_ckpt.keys(): + for rename_prefix in RENAME_PREFIX.keys(): + if rename_prefix in old_key: + new_key = old_key.replace(rename_prefix, + RENAME_PREFIX[rename_prefix]) + RENAME_KEYS[new_key] = old_key + for new_key, old_key in RENAME_KEYS.items(): + converted_ckpt[new_key] = converted_ckpt.pop(old_key) + + # Extract weights and rename the keys + for new_key, (old_key, indices) in EXTRACT_KEYS.items(): + cur_layers = orig_ckpt[old_key] + converted_layers = [] + for (start, end) in indices: + if end != -1: + converted_layers.append(cur_layers[start:end]) + else: + converted_layers.append(cur_layers[start:]) + converted_layers = torch.cat(converted_layers, 0) + converted_ckpt[new_key] = converted_layers + if old_key in converted_ckpt.keys(): + converted_ckpt.pop(old_key) + + # Check the converted checkpoint by loading to the model + load_state_dict(model, converted_ckpt, strict=True) + checkpoint['state_dict'] = converted_ckpt + torch.save(checkpoint, args.out) + + +if __name__ == '__main__': + main() From 71387744925698ac8231e0a9bbb4e40b155e763b Mon Sep 17 00:00:00 2001 From: ChaimZhu Date: Tue, 26 Oct 2021 20:48:17 +0800 Subject: [PATCH 2/4] add compac doc --- configs/h3dnet/README.md | 8 ++++++++ docs/compatibility.md | 4 ++++ 2 files changed, 12 insertions(+) diff --git a/configs/h3dnet/README.md b/configs/h3dnet/README.md index 8fe8307751..3eceaf7fb3 100644 --- a/configs/h3dnet/README.md +++ b/configs/h3dnet/README.md @@ -22,3 +22,11 @@ We implement H3DNet and provide the result and checkpoints on ScanNet datasets. | Backbone | Lr schd | Mem (GB) | Inf time (fps) | AP@0.25 |AP@0.5| Download | | :---------: | :-----: | :------: | :------------: | :----: |:----: | :------: | | [MultiBackbone](./h3dnet_3x8_scannet-3d-18class.py) | 3x |7.9||66.43|48.01|[model](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/h3dnet/h3dnet_scannet-3d-18class/h3dnet_scannet-3d-18class_20200830_000136-02e36246.pth) | [log](https://download.openmmlab.com/mmdetection3d/v0.1.0_models/h3dnet/h3dnet_scannet-3d-18class/h3dnet_scannet-3d-18class_20200830_000136.log.json) | + +**Notice**: If your current mmdetection3d version >= 0.6.0, and you are using the checkpoints downloaded from the above links or using checkpoints trained with mmdetection3d version < 0.6.0, the checkpoints have to be first converted via [tools/model_converters/convert_h3dnet_checkpoints.py](../../tools/model_converters/convert_h3dnet_checkpoints.py): + +``` +python ./tools/model_converters/convert_h3dnet_checkpoints.py ${ORIGINAL_CHECKPOINT_PATH} --out=${NEW_CHECKPOINT_PATH} +``` + +Then you can use the converted checkpoints following [getting_started.md](../../docs/getting_started.md). diff --git a/docs/compatibility.md b/docs/compatibility.md index 5cc8dcf84a..0bfcc3f3f2 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -78,3 +78,7 @@ Please refer to the SUNRGBD [README.md](https://github.com/open-mmlab/mmdetectio ### VoteNet model structure update In MMDetection 0.6.0, we updated the model structure of VoteNet, therefore model checkpoints generated by MMDetection < 0.6.0 should be first converted to a format compatible with the latest VoteNet structure via this [script](https://github.com/open-mmlab/mmdetection3d/blob/master/tools/model_converters/convert_votenet_checkpoints.py). For more details, please refer to the VoteNet [README.md](https://github.com/open-mmlab/mmdetection3d/tree/master/configs/votenet/README.md/) + +### H3DNet model structure update + +In MMDetection 0.6.0, we updated the model structure of H3DNet, therefore model checkpoints generated by MMDetection < 0.6.0 should be first converted to a format compatible with the latest H3DNet structure via this [script](https://github.com/open-mmlab/mmdetection3d/blob/master/tools/model_converters/convert_h3dnet_checkpoints.py). For more details, please refer to the VoteNet [README.md](https://github.com/open-mmlab/mmdetection3d/tree/master/configs/h3dnet/README.md/) From bbe12cc65f71d4c3185763645a75e0d7ddef2ee1 Mon Sep 17 00:00:00 2001 From: ChaimZhu Date: Wed, 27 Oct 2021 14:13:16 +0800 Subject: [PATCH 3/4] merge compac doc --- docs/compatibility.md | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/docs/compatibility.md b/docs/compatibility.md index 0bfcc3f3f2..96e347fb46 100644 --- a/docs/compatibility.md +++ b/docs/compatibility.md @@ -75,10 +75,6 @@ Please refer to the SUNRGBD [README.md](https://github.com/open-mmlab/mmdetectio ## 0.6.0 -### VoteNet model structure update +### VoteNet and H3DNet model structure update -In MMDetection 0.6.0, we updated the model structure of VoteNet, therefore model checkpoints generated by MMDetection < 0.6.0 should be first converted to a format compatible with the latest VoteNet structure via this [script](https://github.com/open-mmlab/mmdetection3d/blob/master/tools/model_converters/convert_votenet_checkpoints.py). For more details, please refer to the VoteNet [README.md](https://github.com/open-mmlab/mmdetection3d/tree/master/configs/votenet/README.md/) - -### H3DNet model structure update - -In MMDetection 0.6.0, we updated the model structure of H3DNet, therefore model checkpoints generated by MMDetection < 0.6.0 should be first converted to a format compatible with the latest H3DNet structure via this [script](https://github.com/open-mmlab/mmdetection3d/blob/master/tools/model_converters/convert_h3dnet_checkpoints.py). For more details, please refer to the VoteNet [README.md](https://github.com/open-mmlab/mmdetection3d/tree/master/configs/h3dnet/README.md/) +In MMDetection 0.6.0, we updated the model structures of VoteNet and H3DNet, therefore model checkpoints generated by MMDetection < 0.6.0 should be first converted to a format compatible with the latest structures via [convert_votenet_checkpoints.py](https://github.com/open-mmlab/mmdetection3d/blob/master/tools/model_converters/convert_votenet_checkpoints.py) and [convert_h3dnet_checkpoints.py](https://github.com/open-mmlab/mmdetection3d/blob/master/tools/model_converters/convert_h3dnet_checkpoints.py) . For more details, please refer to the VoteNet [README.md](https://github.com/open-mmlab/mmdetection3d/tree/master/configs/votenet/README.md/) and H3DNet [README.md](https://github.com/open-mmlab/mmdetection3d/tree/master/configs/h3dnet/README.md/). From d929a76b50b186dd48385f4151fff34b6b9b4459 Mon Sep 17 00:00:00 2001 From: ChaimZhu Date: Wed, 27 Oct 2021 21:51:54 +0800 Subject: [PATCH 4/4] add header --- tools/model_converters/convert_h3dnet_checkpoints.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tools/model_converters/convert_h3dnet_checkpoints.py b/tools/model_converters/convert_h3dnet_checkpoints.py index 513b4b21ef..9368a87182 100644 --- a/tools/model_converters/convert_h3dnet_checkpoints.py +++ b/tools/model_converters/convert_h3dnet_checkpoints.py @@ -1,3 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. import argparse import tempfile import torch