From 0fcf834eb70479aba7e6d486d929cef6510b4584 Mon Sep 17 00:00:00 2001 From: JingweiZhang12 Date: Mon, 26 Dec 2022 11:51:51 +0800 Subject: [PATCH 01/18] init centerformer in projects --- mmdet3d/apis/inference.py | 3 +- mmdet3d/datasets/transforms/loading.py | 8 + mmdet3d/datasets/transforms/transforms_3d.py | 4 + mmdet3d/engine/hooks/__init__.py | 5 +- .../hooks/disable_object_sample_hook.py | 55 + mmdet3d/models/layers/sparse_block.py | 5 +- mmdet3d/models/layers/spconv/__init__.py | 4 +- .../centerformer/centerformer/__init__.py | 9 + .../centerformer/centerformer/centerformer.py | 191 ++++ .../centerformer/centerformer_head.py | 640 ++++++++++++ projects/centerformer/centerformer/losses.py | 56 + .../centerformer/rpn_transformer.py | 975 ++++++++++++++++++ .../centerformer/utils/__init__.py | 11 + .../centerformer/utils/attention.py | 42 + .../centerformer/utils/bbox_ops.py | 76 ++ .../utils/multi_scale_deform_attn.py | 203 ++++ .../centerformer/utils/sparse_block.py | 88 ++ .../centerformer/utils/transformer.py | 421 ++++++++ .../configs/centerform_voxel01.py | 309 ++++++ 19 files changed, 3099 insertions(+), 6 deletions(-) create mode 100644 mmdet3d/engine/hooks/disable_object_sample_hook.py create mode 100644 projects/centerformer/centerformer/__init__.py create mode 100644 projects/centerformer/centerformer/centerformer.py create mode 100644 projects/centerformer/centerformer/centerformer_head.py create mode 100644 projects/centerformer/centerformer/losses.py create mode 100644 projects/centerformer/centerformer/rpn_transformer.py create mode 100644 projects/centerformer/centerformer/utils/__init__.py create mode 100644 projects/centerformer/centerformer/utils/attention.py create mode 100644 projects/centerformer/centerformer/utils/bbox_ops.py create mode 100644 projects/centerformer/centerformer/utils/multi_scale_deform_attn.py create mode 100644 projects/centerformer/centerformer/utils/sparse_block.py create mode 100644 projects/centerformer/centerformer/utils/transformer.py create mode 100644 projects/centerformer/configs/centerform_voxel01.py diff --git a/mmdet3d/apis/inference.py b/mmdet3d/apis/inference.py index 2273aa95a4..f9726e1511 100644 --- a/mmdet3d/apis/inference.py +++ b/mmdet3d/apis/inference.py @@ -67,7 +67,8 @@ def init_model(config: Union[str, Path, Config], if checkpoint is not None: checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') - dataset_meta = checkpoint['meta'].get('dataset_meta', None) + if 'meta' in checkpoint: + dataset_meta = checkpoint['meta'].get('dataset_meta', None) # save the dataset_meta in the model for convenience if 'dataset_meta' in checkpoint.get('meta', {}): # mmdet3d 1.x diff --git a/mmdet3d/datasets/transforms/loading.py b/mmdet3d/datasets/transforms/loading.py index 718ee9eb78..a4d3f24937 100644 --- a/mmdet3d/datasets/transforms/loading.py +++ b/mmdet3d/datasets/transforms/loading.py @@ -532,6 +532,8 @@ class LoadPointsFromFile(BaseTransform): or use_dim=[0, 1, 2, 3] to use the intensity dimension. shift_height (bool): Whether to use shifted height. Defaults to False. use_color (bool): Whether to use color features. Defaults to False. + norm_intensity (bool): Whether to normlize the intensity. Defaults to + False. file_client_args (dict): Arguments to instantiate a FileClient. See :class:`mmengine.fileio.FileClient` for details. Defaults to dict(backend='disk'). 
@@ -544,6 +546,7 @@ def __init__( use_dim: Union[int, List[int]] = [0, 1, 2], shift_height: bool = False, use_color: bool = False, + norm_intensity: bool = False, file_client_args: dict = dict(backend='disk') ) -> None: self.shift_height = shift_height @@ -557,6 +560,7 @@ def __init__( self.coord_type = coord_type self.load_dim = load_dim self.use_dim = use_dim + self.norm_intensity = norm_intensity self.file_client_args = file_client_args.copy() self.file_client = None @@ -599,6 +603,10 @@ def transform(self, results: dict) -> dict: points = self._load_points(pts_file_path) points = points.reshape(-1, self.load_dim) points = points[:, self.use_dim] + if self.norm_intensity: + assert len(self.use_dim) >= 4, \ + f'When using intensity norm, expect used dimensions >= 4, got {len(self.use_dim)}' # noqa: E501 + points[:, 3] = np.tanh(points[:, 3]) attribute_dims = None if self.shift_height: diff --git a/mmdet3d/datasets/transforms/transforms_3d.py b/mmdet3d/datasets/transforms/transforms_3d.py index 668e1fb15b..495f0c1e2d 100644 --- a/mmdet3d/datasets/transforms/transforms_3d.py +++ b/mmdet3d/datasets/transforms/transforms_3d.py @@ -359,6 +359,7 @@ def __init__(self, db_sampler['type'] = 'DataBaseSampler' self.db_sampler = TRANSFORMS.build(db_sampler) self.use_ground_plane = use_ground_plane + self.disabled = False @staticmethod def remove_points_in_boxes(points: BasePoints, @@ -387,6 +388,9 @@ def transform(self, input_dict: dict) -> dict: 'points', 'gt_bboxes_3d', 'gt_labels_3d' keys are updated in the result dict. """ + if self.disabled: + return input_dict + gt_bboxes_3d = input_dict['gt_bboxes_3d'] gt_labels_3d = input_dict['gt_labels_3d'] diff --git a/mmdet3d/engine/hooks/__init__.py b/mmdet3d/engine/hooks/__init__.py index 1d47e4d549..578f173d41 100644 --- a/mmdet3d/engine/hooks/__init__.py +++ b/mmdet3d/engine/hooks/__init__.py @@ -1,5 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. from .benchmark_hook import BenchmarkHook +from .disable_object_sample_hook import DisableObjectSampleHook from .visualization_hook import Det3DVisualizationHook -__all__ = ['Det3DVisualizationHook', 'BenchmarkHook'] +__all__ = [ + 'Det3DVisualizationHook', 'BenchmarkHook', 'DisableObjectSampleHook' +] diff --git a/mmdet3d/engine/hooks/disable_object_sample_hook.py b/mmdet3d/engine/hooks/disable_object_sample_hook.py new file mode 100644 index 0000000000..293e18d105 --- /dev/null +++ b/mmdet3d/engine/hooks/disable_object_sample_hook.py @@ -0,0 +1,55 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmengine.hooks import Hook +from mmengine.model import is_model_wrapper +from mmengine.runner import Runner + +from mmdet3d.datasets.transforms import ObjectSample +from mmdet3d.registry import HOOKS + + +@HOOKS.register_module() +class DisableObjectSampleHook(Hook): + """The hook of disabling augmentations during training. + + Args: + num_last_epochs (int): The number of latter epochs in the end of the + training to close the data augmentation. Default: 15. + skip_type_keys (list[str], optional): Sequence of type string to be + skipped in the data pipeline. Default: ('ObjectSample') + """ + + def __init__(self, disable_after_epoch: int = 15): + self.disable_after_epoch = disable_after_epoch + self._restart_dataloader = False + + def before_train_epoch(self, runner: Runner): + """Close augmentation. + + Args: + runner (Runner): The runner. 
+ """ + epoch = runner.epoch + train_loader = runner.train_dataloader + model = runner.model + # TODO: refactor after mmengine using model wrapper + if is_model_wrapper(model): + model = model.module + if epoch == self.disable_after_epoch: + runner.logger.info('Disable ObjectSample') + for transform in runner.train_dataloader.dataset.pipeline.transforms: + if isinstance(transform, ObjectSample): + assert hasattr(transform, 'disabled') + transform.disabled = True + # The dataset pipeline cannot be updated when persistent_workers + # is True, so we need to force the dataloader's multi-process + # restart. This is a very hacky approach. + if hasattr(train_loader, 'persistent_workers' + ) and train_loader.persistent_workers is True: + train_loader._DataLoader__initialized = False + train_loader._iterator = None + self._restart_dataloader = True + else: + # Once the restart is complete, we need to restore + # the initialization flag. + if self._restart_dataloader: + train_loader._DataLoader__initialized = True diff --git a/mmdet3d/models/layers/sparse_block.py b/mmdet3d/models/layers/sparse_block.py index 14fc4deeda..3d3dd66274 100644 --- a/mmdet3d/models/layers/sparse_block.py +++ b/mmdet3d/models/layers/sparse_block.py @@ -2,10 +2,13 @@ from typing import Tuple, Union from mmcv.cnn import build_conv_layer, build_norm_layer -from mmdet.models.backbones.resnet import BasicBlock, Bottleneck +# from mmdet.models.backbones.resnet import BasicBlock, Bottleneck +from mmdet.models.backbones.resnet import Bottleneck from torch import nn from mmdet3d.utils import OptConfigType +from projects.centerformer.centerformer.utils import \ + BasicBlockBias as BasicBlock # noqa: E501 from .spconv import IS_SPCONV2_AVAILABLE if IS_SPCONV2_AVAILABLE: diff --git a/mmdet3d/models/layers/spconv/__init__.py b/mmdet3d/models/layers/spconv/__init__.py index 98db14a7ac..37b533e0c6 100644 --- a/mmdet3d/models/layers/spconv/__init__.py +++ b/mmdet3d/models/layers/spconv/__init__.py @@ -6,9 +6,7 @@ except ImportError: IS_SPCONV2_AVAILABLE = False else: - if hasattr(spconv, - '__version__') and spconv.__version__ >= '2.0.0' and hasattr( - spconv, 'pytorch'): + if hasattr(spconv, '__version__') and spconv.__version__ >= '2.0.0': IS_SPCONV2_AVAILABLE = register_spconv2() else: IS_SPCONV2_AVAILABLE = False diff --git a/projects/centerformer/centerformer/__init__.py b/projects/centerformer/centerformer/__init__.py new file mode 100644 index 0000000000..5893950e9b --- /dev/null +++ b/projects/centerformer/centerformer/__init__.py @@ -0,0 +1,9 @@ +from .centerformer import CenterFormer +from .centerformer_head import CenterHeadIoU_1d +from .losses import FastFocalLoss +from .rpn_transformer import RPN_transformer_deformable + +__all__ = [ + 'CenterFormer', 'RPN_transformer_deformable', 'CenterHeadIoU_1d', + 'FastFocalLoss' +] diff --git a/projects/centerformer/centerformer/centerformer.py b/projects/centerformer/centerformer/centerformer.py new file mode 100644 index 0000000000..5107b92993 --- /dev/null +++ b/projects/centerformer/centerformer/centerformer.py @@ -0,0 +1,191 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional + +import torch +from torch import Tensor +from torch.nn.modules.batchnorm import _BatchNorm + +from mmdet3d.models.detectors import Base3DDetector +from mmdet3d.registry import MODELS +from mmdet3d.structures import Det3DDataSample + + +@MODELS.register_module() +class CenterFormer(Base3DDetector): + """Base class of center-based 3D detector. 
+ + Args: + voxel_encoder (dict, optional): Point voxelization + encoder layer. Defaults to None. + middle_encoder (dict, optional): Middle encoder layer + of points cloud modality. Defaults to None. + pts_fusion_layer (dict, optional): Fusion layer. + Defaults to None. + backbone (dict, optional): Backbone of extracting + points features. Defaults to None. + neck (dict, optional): Neck of extracting + points features. Defaults to None. + bbox_head (dict, optional): Bboxes head of + point cloud modality. Defaults to None. + train_cfg (dict, optional): Train config of model. + Defaults to None. + test_cfg (dict, optional): Train config of model. + Defaults to None. + init_cfg (dict, optional): Initialize config of + model. Defaults to None. + data_preprocessor (dict or ConfigDict, optional): The pre-process + config of :class:`Det3DDataPreprocessor`. Defaults to None. + """ + + def __init__(self, + voxel_encoder: Optional[dict] = None, + middle_encoder: Optional[dict] = None, + backbone: Optional[dict] = None, + neck: Optional[dict] = None, + bbox_head: Optional[dict] = None, + train_cfg: Optional[dict] = None, + test_cfg: Optional[dict] = None, + init_cfg: Optional[dict] = None, + data_preprocessor: Optional[dict] = None, + **kwargs): + super(CenterFormer, self).__init__( + init_cfg=init_cfg, data_preprocessor=data_preprocessor, **kwargs) + + if voxel_encoder: + self.voxel_encoder = MODELS.build(voxel_encoder) + if middle_encoder: + self.middle_encoder = MODELS.build(middle_encoder) + if backbone: + backbone.update(train_cfg=train_cfg, test_cfg=test_cfg) + self.backbone = MODELS.build(backbone) + if neck is not None: + self.neck = MODELS.build(neck) + if bbox_head: + bbox_head.update(train_cfg=train_cfg, test_cfg=test_cfg) + self.bbox_head = MODELS.build(bbox_head) + + self.train_cfg = train_cfg + self.test_cfg = test_cfg + + def init_weights(self): + for m in self.modules(): + if isinstance(m, _BatchNorm): + torch.nn.init.uniform_(m.weight) + + @property + def with_bbox(self): + """bool: Whether the detector has a 3D box head.""" + return hasattr(self, 'bbox_head') and self.bbox_head is not None + + @property + def with_backbone(self): + """bool: Whether the detector has a 3D backbone.""" + return hasattr(self, 'backbone') and self.backbone is not None + + @property + def with_fusion(self): + """bool: Whether the detector has a fusion layer.""" + return hasattr(self, + 'pts_fusion_layer') and self.fusion_layer is not None + + @property + def with_neck(self): + """bool: Whether the detector has a neck in 3D detector branch.""" + return hasattr(self, 'neck') and self.neck is not None + + @property + def with_voxel_encoder(self): + """bool: Whether the detector has a voxel encoder.""" + return hasattr(self, + 'voxel_encoder') and self.voxel_encoder is not None + + @property + def with_middle_encoder(self): + """bool: Whether the detector has a middle encoder.""" + return hasattr(self, + 'middle_encoder') and self.middle_encoder is not None + + def _forward(self): + pass + + def extract_feat(self, batch_inputs_dict: dict, + batch_input_metas: List[dict]) -> tuple: + """Extract features from images and points. + Args: + batch_inputs_dict (dict): Dict of batch inputs. It + contains + - points (List[tensor]): Point cloud of multiple inputs. + - imgs (tensor): Image tensor with shape (B, C, H, W). + batch_input_metas (list[dict]): Meta information of multiple inputs + in a batch. + Returns: + tuple: Two elements in tuple arrange as + image features and point cloud features. 
+ """ + voxel_dict = batch_inputs_dict.get('voxels', None) + voxel_features, feature_coors = self.voxel_encoder( + voxel_dict['voxels'], voxel_dict['coors']) + batch_size = voxel_dict['coors'][-1, 0].item() + 1 + x = self.middle_encoder(voxel_features, feature_coors, batch_size) + + return x + + def loss(self, batch_inputs_dict: Dict[List, torch.Tensor], + batch_data_samples: List[Det3DDataSample], + **kwargs) -> List[Det3DDataSample]: + """ + Args: + batch_inputs_dict (dict): The model input dict which include + 'points' and `imgs` keys. + - points (list[torch.Tensor]): Point cloud of each sample. + - imgs (torch.Tensor): Tensor of batch images, has shape + (B, C, H ,W) + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance_3d`, . + Returns: + dict[str, Tensor]: A dictionary of loss components. + """ + + batch_input_metas = [item.metainfo for item in batch_data_samples] + pts_feats = self.extract_feat(batch_inputs_dict, batch_input_metas) + preds, batch_tatgets = self.backbone(pts_feats, batch_data_samples) + preds = self.bbox_head(preds) + losses = dict() + losses.update(self.bbox_head.loss(preds, batch_tatgets)) + return losses + # return self.bbox_head.predict(preds, batch_tatgets) + + def predict(self, batch_inputs_dict: Dict[str, Optional[Tensor]], + batch_data_samples: List[Det3DDataSample], + **kwargs) -> List[Det3DDataSample]: + """Forward of testing. + Args: + batch_inputs_dict (dict): The model input dict which include + 'points' keys. + - points (list[torch.Tensor]): Point cloud of each sample. + batch_data_samples (List[:obj:`Det3DDataSample`]): The Data + Samples. It usually includes information such as + `gt_instance_3d`. + Returns: + list[:obj:`Det3DDataSample`]: Detection results of the + input sample. Each Det3DDataSample usually contain + 'pred_instances_3d'. And the ``pred_instances_3d`` usually + contains following keys. + - scores_3d (Tensor): Classification scores, has a shape + (num_instances, ) + - labels_3d (Tensor): Labels of bboxes, has a shape + (num_instances, ). + - bbox_3d (:obj:`BaseInstance3DBoxes`): Prediction of bboxes, + contains a tensor with shape (num_instances, 7). 
+ """ + batch_input_metas = [item.metainfo for item in batch_data_samples] + pts_feats = self.extract_feat(batch_inputs_dict, batch_input_metas) + preds, _ = self.backbone(pts_feats, batch_data_samples) + + preds = self.bbox_head(preds) + results_list_3d = self.bbox_head.predict(preds, batch_input_metas) + + detsamples = self.add_pred_to_datasample(batch_data_samples, + results_list_3d) + return detsamples diff --git a/projects/centerformer/centerformer/centerformer_head.py b/projects/centerformer/centerformer/centerformer_head.py new file mode 100644 index 0000000000..dc6ea9feb2 --- /dev/null +++ b/projects/centerformer/centerformer/centerformer_head.py @@ -0,0 +1,640 @@ +# ------------------------------------------------------------------------------ +# Portions of this code are from +# det3d (https://github.com/poodarchu/Det3D/tree/56402d4761a5b73acd23080f537599b0888cce07) # noqa +# Copyright (c) 2019 朱本金 +# Licensed under the MIT License +# ------------------------------------------------------------------------------ + +import copy +import logging + +import numpy as np +import torch +from mmcv.cnn import build_norm_layer +from mmengine.logging import print_log +from mmengine.model import kaiming_init +from mmengine.structures import InstanceData +from torch import nn + +from mmdet3d.models.layers import circle_nms, nms_bev +from mmdet3d.registry import MODELS +from mmdet3d.structures import bbox_overlaps_3d +from .losses import FastFocalLoss +from .utils import boxes_iou3d_gpu_pcdet, rotate_nms_pcdet + + +class SepHead(nn.Module): + + def __init__( + self, + in_channels, + heads, + head_conv=64, + final_kernel=1, + bn=False, + init_bias=-2.19, + norm_cfg=dict(type='BN', eps=1e-3, momentum=0.01), + **kwargs, + ): + super(SepHead, self).__init__(**kwargs) + + self.heads = heads + for head in self.heads: + classes, num_conv = self.heads[head] + + fc = [] + for i in range(num_conv - 1): + fc.append( + nn.Conv1d( + in_channels, + head_conv, + kernel_size=final_kernel, + stride=1, + padding=final_kernel // 2, + bias=True, + )) + if bn: + fc.append(build_norm_layer(norm_cfg, head_conv)[1]) + fc.append(nn.ReLU()) + + fc.append( + nn.Conv1d( + head_conv, + classes, + kernel_size=final_kernel, + stride=1, + padding=final_kernel // 2, + bias=True, + )) + + if 'hm' in head: + fc[-1].bias.data.fill_(init_bias) + else: + for m in fc: + if isinstance(m, nn.Conv1d): + kaiming_init(m) + + fc = nn.Sequential(*fc) + self.__setattr__(head, fc) + + def forward(self, x, y): + for head in self.heads: + x[head] = self.__getattr__(head)(y) + + return x + + +@MODELS.register_module() +class CenterHeadIoU_1d(nn.Module): + + def __init__(self, + in_channels=[ + 128, + ], + tasks=[], + weight=0.25, + iou_weight=1, + corner_weight=1, + code_weights=[], + common_heads=dict(), + logger=None, + init_bias=-2.19, + share_conv_channel=64, + assign_label_window_size=1, + iou_loss=False, + corner_loss=False, + iou_factor=[1, 1, 4], + norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01), + bbox_code_size=7, + test_cfg=None, + **kawrgs): + super(CenterHeadIoU_1d, self).__init__() + + num_classes = [len(t['class_names']) for t in tasks] + self.class_names = [t['class_names'] for t in tasks] + self.code_weights = code_weights + self.bbox_code_size = 7 + self.weight = weight # weight between hm loss and loc loss + self.iou_weight = iou_weight + self.corner_weight = corner_weight + self.iou_factor = iou_factor + + self.in_channels = in_channels + self.num_classes = num_classes + self.test_cfg = test_cfg + + self.crit = 
FastFocalLoss(assign_label_window_size) + self.crit_reg = torch.nn.L1Loss(reduction='none') + self.use_iou_loss = iou_loss + if self.use_iou_loss: + self.crit_iou = torch.nn.SmoothL1Loss(reduction='none') + self.corner_loss = corner_loss + if self.corner_loss: + self.corner_crit = torch.nn.MSELoss(reduction='none') + + self.box_n_dim = 9 if 'vel' in common_heads else 7 + self.use_direction_classifier = False + + if not logger: + logger = logging.getLogger('CenterHeadIoU_1d') + self.logger = logger + + logger.info(f'num_classes: {num_classes}') + + # a shared convolution + self.shared_conv = nn.Sequential( + nn.Conv1d( + in_channels, share_conv_channel, kernel_size=1, bias=True), + build_norm_layer(norm_cfg, share_conv_channel)[1], + nn.ReLU(inplace=True), + ) + + self.tasks = nn.ModuleList() + print_log(f'Use HM Bias: {init_bias}', 'current') + + for num_cls in num_classes: + heads = copy.deepcopy(common_heads) + self.tasks.append( + SepHead( + share_conv_channel, + heads, + bn=True, + init_bias=init_bias, + final_kernel=1, + norm_cfg=norm_cfg)) + + logger.info('Finish CenterHeadIoU Initialization') + + def forward(self, x, *kwargs): + ret_dicts = [] + + y = self.shared_conv(x['ct_feat'].float()) + + for task in self.tasks: + ret_dicts.append(task(x, y)) + + return ret_dicts + + def _sigmoid(self, x): + y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4) + return y + + def loss(self, preds_dicts, example, **kwargs): + losses = {} + for task_id, preds_dict in enumerate(preds_dicts): + # heatmap focal loss + hm_loss = self.crit( + preds_dict['hm'], + example['hm'][task_id], + example['ind'][task_id], + example['mask'][task_id], + example['cat'][task_id], + ) + + target_box = example['anno_box'][task_id] + + if self.corner_loss: + corner_loss = self.corner_crit(preds_dict['corner_hm'], + example['corners'][task_id]) + corner_mask = (example['corners'][task_id] > 0).to(corner_loss) + corner_loss = (corner_loss * corner_mask).sum() / ( + corner_mask.sum() + 1e-4) + losses.update({ + f'{task_id}_corner_loss': + corner_loss * self.corner_weight + }) + + # reconstruct the anno_box from multiple reg heads + if 'vel' in preds_dict: + preds_dict['anno_box'] = torch.cat( + ( + preds_dict['reg'], + preds_dict['height'], + preds_dict['dim'], + preds_dict['vel'], + preds_dict['rot'], + ), + dim=1, + ) + else: + preds_dict['anno_box'] = torch.cat( + ( + preds_dict['reg'], + preds_dict['height'], + preds_dict['dim'], + preds_dict['rot'], + ), + dim=1, + ) + target_box = target_box[..., [0, 1, 2, 3, 4, 5, -2, + -1]] # remove vel target + + # Regression loss for dimension, offset, height, rotation + # get corresponding gt box # B, 500 + target_box, selected_mask, selected_cls = get_corresponding_box( + preds_dict['order'], + example['ind'][task_id], + example['mask'][task_id], + example['cat'][task_id], + target_box, + ) + mask = selected_mask.float().unsqueeze(2) + + weights = self.code_weights + + box_loss = self.crit_reg( + preds_dict['anno_box'].transpose(1, 2) * mask, + target_box * mask) + box_loss = box_loss / (mask.sum() + 1e-4) + box_loss = box_loss.transpose(2, 0).sum(dim=2).sum(dim=1) + + loc_loss = (box_loss * box_loss.new_tensor(weights)).sum() + + if self.use_iou_loss: + with torch.no_grad(): + preds_box = get_box( + preds_dict['anno_box'], + preds_dict['order'], + self.test_cfg, + preds_dict['hm'].shape[2], + preds_dict['hm'].shape[3], + ) + cur_gt = get_box_gt( + target_box, + preds_dict['order'], + self.test_cfg, + preds_dict['hm'].shape[2], + preds_dict['hm'].shape[3], + ) + # iou_targets 
= bbox_overlaps_3d( + # preds_box.reshape(-1, 7), + # cur_gt.reshape(-1, 7), + # coordinate='lidar')[ + # range(preds_box.reshape(-1, 7).shape[0]), + # range(cur_gt.reshape(-1, 7).shape[0])] + + preds_box[:, :, 2] += preds_box[:, :, 5] / 2 + cur_gt[:, :, 2] += cur_gt[:, :, 5] / 2 + iou_targets = boxes_iou3d_gpu_pcdet( + preds_box.reshape(-1, 7), cur_gt.reshape( + -1, 7))[range(preds_box.reshape(-1, 7).shape[0]), + range(cur_gt.reshape(-1, 7).shape[0])] + iou_targets[torch.isnan(iou_targets)] = 0 + iou_targets = 2 * iou_targets - 1 + iou_loss = self.crit_iou(preds_dict['iou'].reshape(-1), + iou_targets) * mask.reshape(-1) + iou_loss = iou_loss.sum() / (mask.sum() + 1e-4) + + losses.update( + {f'{task_id}_iou_loss': iou_loss * self.iou_weight}) + + # loss = hm_loss + self.weight * loc_loss + # if self.use_iou_loss: + # loss = loss + self.iou_weight * iou_loss + # if self.corner_loss: + # loss = loss + self.corner_weight * corner_loss + losses.update({ + f'{task_id}_hm_loss': hm_loss, + f'{task_id}_loc_loss': loc_loss * self.weight, + # 'loc_loss_elem': box_loss, + # 'num_positive': example['mask'][task_id].float().sum(), + }) + + # """convert batch-key to key-batch""" + # rets_merged = defaultdict(list) + # for ret in rets: + # for k, v in ret.items(): + # rets_merged[k].append(v) + + return losses + + @torch.no_grad() + def predict(self, preds_dicts, batch_input_metas, **kwargs): + """decode, nms, then return the detection result. + + Additionally support double flip testing + """ + # get loss info + rets = [] + metas = [] + + post_center_range = self.test_cfg.post_center_limit_range + if len(post_center_range) > 0: + post_center_range = torch.tensor( + post_center_range, + dtype=preds_dicts[0]['scores'].dtype, + device=preds_dicts[0]['scores'].device, + ) + + for task_id, preds_dict in enumerate(preds_dicts): + # convert B C N to B N C + for key, val in preds_dict.items(): + if torch.is_tensor(preds_dict[key]): + if len(preds_dict[key].shape) == 3: + preds_dict[key] = val.permute(0, 2, 1).contiguous() + + batch_size = preds_dict['scores'].shape[0] + + # if 'metadata' not in example or len(example['metadata']) == 0: + # meta_list = [None] * batch_size + # else: + # meta_list = example['metadata'] + + batch_score = preds_dict['scores'] + batch_label = preds_dict['labels'] + batch_mask = preds_dict['mask'] + if self.use_iou_loss: + batch_iou = preds_dict['iou'].squeeze(2) + else: + batch_iou = None + if 'corner_hm' in preds_dict: + batch_corner_hm = preds_dict['corner_hm'] + else: + batch_corner_hm = None + + batch_dim = torch.exp(preds_dict['dim']) + + batch_rots = preds_dict['rot'][..., 0:1] + batch_rotc = preds_dict['rot'][..., 1:2] + + batch_reg = preds_dict['reg'] + batch_hei = preds_dict['height'] + batch_rot = torch.atan2(batch_rots, batch_rotc) + if self.use_iou_loss: + batch_iou = (batch_iou + 1) * 0.5 + batch_iou = torch.clamp(batch_iou, min=0.0, max=1.0) + + batch, _, H, W = preds_dict['hm'].size() + + ys, xs = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)]) + ys = ys.view(1, H, W).repeat(batch, 1, 1).to(batch_score) + xs = xs.view(1, H, W).repeat(batch, 1, 1).to(batch_score) + + obj_num = preds_dict['order'].shape[1] + batch_id = np.indices((batch, obj_num))[0] + batch_id = torch.from_numpy(batch_id).to(preds_dict['order']) + + xs = ( + xs.view(batch, -1, 1)[batch_id, preds_dict['order']] + + batch_reg[:, :, 0:1]) + ys = ( + ys.view(batch, -1, 1)[batch_id, preds_dict['order']] + + batch_reg[:, :, 1:2]) + + xs = ( + xs * self.test_cfg.out_size_factor * + 
self.test_cfg.voxel_size[0] + self.test_cfg.pc_range[0]) + ys = ( + ys * self.test_cfg.out_size_factor * + self.test_cfg.voxel_size[1] + self.test_cfg.pc_range[1]) + + if 'vel' in preds_dict: + batch_vel = preds_dict['vel'] + batch_box_preds = torch.cat( + [xs, ys, batch_hei, batch_dim, batch_vel, batch_rot], + dim=2) + else: + batch_box_preds = torch.cat( + [xs, ys, batch_hei, batch_dim, batch_rot], dim=2) + + # metas.append(meta_list) + + if self.test_cfg.get('per_class_nms', False): + pass + else: + rets.append( + self.post_processing( + batch_input_metas, + batch_box_preds, + batch_score, + batch_label, + self.test_cfg, + post_center_range, + task_id, + batch_mask, + batch_iou, + )) + + # Merge branches results + ret_list = [] + num_samples = len(rets[0]) + + ret_list = [] + for i in range(num_samples): + temp_instances = InstanceData() + for k in rets[0][i].keys(): + if k == 'bboxes': + bboxes = torch.cat([ret[i][k] for ret in rets]) + bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5 + # The original CenterFormer model predict (..., w,l,h) + # Note that this is used to align the precision of + # converted model + # bboxes[:, 4], bboxes[:, 3] = bboxes[:, 3].clone( + # ), bboxes[:, 4].clone() + # bboxes[:, 6] = -bboxes[:, 6] - np.pi / 2 + bboxes = batch_input_metas[i]['box_type_3d']( + bboxes, self.bbox_code_size) + elif k == 'labels': + flag = 0 + for j, num_class in enumerate(self.num_classes): + rets[j][i][k] += flag + flag += num_class + labels = torch.cat([ret[i][k] for ret in rets]) + elif k == 'scores': + scores = torch.cat([ret[i][k] for ret in rets]) + # ret['metadata'] = metas[0][i] + + temp_instances.bboxes_3d = bboxes + temp_instances.scores_3d = scores + temp_instances.labels_3d = labels + ret_list.append(temp_instances) + + return ret_list + + @torch.no_grad() + def post_processing( + self, + img_metas, + batch_box_preds, + batch_score, + batch_label, + test_cfg, + post_center_range, + task_id, + batch_mask, + batch_iou, + ): + batch_size = len(batch_score) + + prediction_dicts = [] + for i in range(batch_size): + box_preds = batch_box_preds[i] + scores = batch_score[i] + labels = batch_label[i] + mask = batch_mask[i] + + distance_mask = (box_preds[..., :3] >= post_center_range[:3]).all( + 1) & (box_preds[..., :3] <= post_center_range[3:]).all(1) + + mask = mask & distance_mask + + box_preds = box_preds[mask] + scores = scores[mask] + labels = labels[mask] + + if self.use_iou_loss: + iou_factor = torch.LongTensor(self.iou_factor).to(labels) + ious = batch_iou[i][mask] + ious = torch.pow(ious, iou_factor[labels]) + scores = scores * ious + + boxes_for_nms = box_preds[:, [0, 1, 2, 3, 4, 5, -1]] + + if test_cfg.get('circular_nms', False): + centers = boxes_for_nms[:, [0, 1]] + boxes = torch.cat([centers, scores.view(-1, 1)], dim=1) + selected = _circle_nms( + boxes, + min_radius=test_cfg.min_radius[task_id], + post_max_size=test_cfg.nms.nms_post_max_size, + ) + elif test_cfg.nms.get('use_multi_class_nms', False): + # multi class nms + selected = [] + for c in range(3): + class_mask = labels == c + if class_mask.sum() > 0: + class_idx = class_mask.nonzero() + # boxes_for_nms = xywhr2xyxyr( + # img_metas[i]['box_type_3d']( + # box_preds[:, :], self.bbox_code_size).bev) + # select = nms_bev( + # boxes_for_nms[class_mask].float(), + # scores[class_mask].float(), + # thresh=test_cfg.nms.nms_iou_threshold[c], + # pre_max_size=test_cfg.nms.nms_pre_max_size[c], + # post_max_size=test_cfg.nms.nms_post_max_size[c], + # ) + select = rotate_nms_pcdet( + boxes_for_nms[class_mask].float(), 
+ scores[class_mask].float(), + thresh=test_cfg.nms.nms_iou_threshold[c], + pre_maxsize=test_cfg.nms.nms_pre_max_size[c], + post_max_size=test_cfg.nms.nms_post_max_size[c], + ) + selected.append(class_idx[select, 0]) + if len(selected) > 0: + selected = torch.cat(selected, dim=0) + else: + selected = nms_bev( + boxes_for_nms.float(), + scores.float(), + thresh=test_cfg.nms.nms_iou_threshold, + pre_max_size=test_cfg.nms.nms_pre_max_size, + post_max_size=test_cfg.nms.nms_post_max_size, + ) + + selected_boxes = box_preds[selected] + selected_scores = scores[selected] + selected_labels = labels[selected] + + prediction_dict = { + 'bboxes': selected_boxes, + 'scores': selected_scores, + 'labels': selected_labels, + } + + prediction_dicts.append(prediction_dict) + + return prediction_dicts + + +def _circle_nms(boxes, min_radius, post_max_size=83): + """NMS according to center distance.""" + keep = np.array(circle_nms(boxes.cpu().numpy(), + thresh=min_radius))[:post_max_size] + + keep = torch.from_numpy(keep).long().to(boxes.device) + + return keep + + +def get_box(pred_boxs, order, test_cfg, H, W): + batch = pred_boxs.shape[0] + obj_num = order.shape[1] + ys, xs = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)]) + ys = ys.view(1, H, W).repeat(batch, 1, 1).to(pred_boxs) + xs = xs.view(1, H, W).repeat(batch, 1, 1).to(pred_boxs) + + batch_id = np.indices((batch, obj_num))[0] + batch_id = torch.from_numpy(batch_id).to(order) + xs = xs.view(batch, H * W)[batch_id, order].unsqueeze(1) + pred_boxs[:, + 0:1] + ys = ys.view(batch, H * W)[batch_id, order].unsqueeze(1) + pred_boxs[:, + 1:2] + + xs = xs * test_cfg.out_size_factor * test_cfg.voxel_size[ + 0] + test_cfg.pc_range[0] + ys = ys * test_cfg.out_size_factor * test_cfg.voxel_size[ + 1] + test_cfg.pc_range[1] + + rot = torch.atan2(pred_boxs[:, 6:7], pred_boxs[:, 7:8]) + pred = torch.cat( + [xs, ys, pred_boxs[:, 2:3], + torch.exp(pred_boxs[:, 3:6]), rot], dim=1) + pred[:, 2] = pred[:, 2] - pred[:, 5] / 2 + + return torch.transpose(pred, 1, 2).contiguous() # B M 7 + + +def get_box_gt(gt_boxs, order, test_cfg, H, W): + batch = gt_boxs.shape[0] + obj_num = order.shape[1] + ys, xs = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)]) + ys = ys.view(1, H, W).repeat(batch, 1, 1).to(gt_boxs) + xs = xs.view(1, H, W).repeat(batch, 1, 1).to(gt_boxs) + + batch_id = np.indices((batch, obj_num))[0] + batch_id = torch.from_numpy(batch_id).to(order) + + batch_gt_dim = torch.exp(gt_boxs[..., 3:6]) + batch_gt_hei = gt_boxs[..., 2:3] + batch_gt_rot = torch.atan2(gt_boxs[..., -2:-1], gt_boxs[..., -1:]) + xs = xs.view(batch, H * W)[batch_id, order].unsqueeze(2) + gt_boxs[..., + 0:1] + ys = ys.view(batch, H * W)[batch_id, order].unsqueeze(2) + gt_boxs[..., + 1:2] + + xs = xs * test_cfg.out_size_factor * test_cfg.voxel_size[ + 0] + test_cfg.pc_range[0] + ys = ys * test_cfg.out_size_factor * test_cfg.voxel_size[ + 1] + test_cfg.pc_range[1] + + batch_box_targets = torch.cat( + [xs, ys, batch_gt_hei, batch_gt_dim, batch_gt_rot], dim=-1) + + batch_box_targets[..., + 2] = batch_box_targets[..., + 2] - batch_box_targets[..., 5] / 2 + + return batch_box_targets # B M 7 + + +def get_corresponding_box(x_ind, y_ind, y_mask, y_cls, target_box): + # find the id in y which has the same ind in x + select_target = torch.zeros(x_ind.shape[0], x_ind.shape[1], + target_box.shape[2]).to(target_box) + select_mask = torch.zeros_like(x_ind).to(y_mask) + select_cls = torch.zeros_like(x_ind).to(y_cls) + + for i in range(x_ind.shape[0]): + idx = 
torch.arange(y_ind[i].shape[-1]).to(x_ind) + idx = idx[y_mask[i]] + box_cls = y_cls[i][y_mask[i]] + valid_y_ind = y_ind[i][y_mask[i]] + match = (x_ind[i].unsqueeze(1) == valid_y_ind.unsqueeze(0)).nonzero() + select_target[i, match[:, 0]] = target_box[i, idx[match[:, 1]]] + select_mask[i, match[:, 0]] = 1 + select_cls[i, match[:, 0]] = box_cls[match[:, 1]] + + return select_target, select_mask, select_cls diff --git a/projects/centerformer/centerformer/losses.py b/projects/centerformer/centerformer/losses.py new file mode 100644 index 0000000000..daa93c6a2d --- /dev/null +++ b/projects/centerformer/centerformer/losses.py @@ -0,0 +1,56 @@ +import torch +from torch import nn + +from mmdet3d.registry import MODELS + + +def _gather_feat(feat, ind, mask=None): + dim = feat.size(2) + ind = ind.unsqueeze(2).expand(ind.size(0), ind.size(1), dim) + feat = feat.gather(1, ind) + if mask is not None: + mask = mask.unsqueeze(2).expand_as(feat) + feat = feat[mask] + feat = feat.view(-1, dim) + return feat + + +def _transpose_and_gather_feat(feat, ind): + feat = feat.permute(0, 2, 3, 1).contiguous() + feat = feat.view(feat.size(0), -1, feat.size(3)) + feat = _gather_feat(feat, ind) + return feat + + +@MODELS.register_module() +class FastFocalLoss(nn.Module): + """Reimplemented focal loss, exactly the same as the CornerNet version. + + Faster and costs much less memory. + """ + + def __init__(self, focal_factor=2): + super(FastFocalLoss, self).__init__() + self.focal_factor = focal_factor + + def forward(self, out, target, ind, mask, cat): + ''' + Arguments: + out, target: B x C x H x W + ind, mask: B x M + cat (category id for peaks): B x M + ''' + mask = mask.float() + gt = torch.pow(1 - target, 4) + neg_loss = torch.log(1 - out) * torch.pow(out, self.focal_factor) * gt + neg_loss = neg_loss.sum() + + pos_pred_pix = _transpose_and_gather_feat(out, ind) # B x M x C + pos_pred = pos_pred_pix.gather(2, cat.unsqueeze(2)) # B x M + num_pos = mask.sum() + pos_loss = torch.log(pos_pred) * torch.pow( + 1 - pos_pred, self.focal_factor) * mask.unsqueeze(2) + pos_loss = pos_loss.sum() + if num_pos == 0: + return -neg_loss + return -(pos_loss + neg_loss) / num_pos diff --git a/projects/centerformer/centerformer/rpn_transformer.py b/projects/centerformer/centerformer/rpn_transformer.py new file mode 100644 index 0000000000..7379e47f96 --- /dev/null +++ b/projects/centerformer/centerformer/rpn_transformer.py @@ -0,0 +1,975 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from mmcv.cnn import build_norm_layer +from mmdet.models.utils import multi_apply +from mmengine.logging import print_log +from mmengine.structures import InstanceData +from torch import Tensor, nn +from torch.nn import functional as F + +from mmdet3d.models.utils import draw_heatmap_gaussian, gaussian_radius +from mmdet3d.registry import MODELS +from mmdet3d.structures import center_to_corner_box2d +from .utils import Deform_Transformer + + +class ChannelAttention(nn.Module): + + def __init__(self, in_planes, ratio=16): + super(ChannelAttention, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.max_pool = nn.AdaptiveMaxPool2d(1) + + self.fc = nn.Sequential( + nn.Conv2d(in_planes, in_planes // 16, 1, bias=False), + nn.ReLU(), + nn.Conv2d(in_planes // 16, in_planes, 1, bias=False), + ) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = self.fc(self.avg_pool(x)) + max_out = self.fc(self.max_pool(x)) + out = avg_out + max_out + return self.sigmoid(out) * x + + +class SpatialAttention(nn.Module): + + def __init__(self, kernel_size=7): + super(SpatialAttention, self).__init__() + + self.conv1 = nn.Conv2d( + 2, 1, kernel_size, padding=kernel_size // 2, bias=False) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = torch.mean(x, dim=1, keepdim=True) + max_out, _ = torch.max(x, dim=1, keepdim=True) + y = torch.cat([avg_out, max_out], dim=1) + y = self.conv1(y) + return self.sigmoid(y) * x + + +class SpatialAttention_mtf(nn.Module): + + def __init__(self, kernel_size=7): + super(SpatialAttention_mtf, self).__init__() + + self.conv1 = nn.Conv2d( + 2, 1, kernel_size, padding=kernel_size // 2, bias=False) + self.sigmoid = nn.Sigmoid() + + def forward(self, curr, prev): + avg_out = torch.mean(curr, dim=1, keepdim=True) + max_out, _ = torch.max(curr, dim=1, keepdim=True) + y = torch.cat([avg_out, max_out], dim=1) + y = self.conv1(y) + return self.sigmoid(y) * prev + + +class RPN_transformer_base(nn.Module): + + def __init__( + self, + layer_nums, # [2,2,2] + ds_num_filters, # [128,256,64] + num_input_features, # 256 + transformer_config=None, + hm_head_layer=2, + corner_head_layer=2, + corner=False, + assign_label_window_size=1, + classes=3, + use_gt_training=False, + norm_cfg=None, + logger=None, + init_bias=-2.19, + score_threshold=0.1, + obj_num=500, + **kwargs): + super(RPN_transformer_base, self).__init__() + self._layer_strides = [1, 2, -4] + self._num_filters = ds_num_filters + self._layer_nums = layer_nums + self._num_input_features = num_input_features + self.score_threshold = score_threshold + self.transformer_config = transformer_config + self.corner = corner + self.obj_num = obj_num + self.use_gt_training = use_gt_training + self.window_size = assign_label_window_size**2 + self.cross_attention_kernel_size = [3, 3, 3] + self.batch_id = None + + if norm_cfg is None: + norm_cfg = dict(type='BN', eps=1e-3, momentum=0.01) + self._norm_cfg = norm_cfg + + assert len(self._layer_strides) == len(self._layer_nums) + assert len(self._num_filters) == len(self._layer_nums) + assert self.transformer_config is not None + + in_filters = [ + self._num_input_features, + self._num_filters[0], + self._num_filters[1], + ] + blocks = [] + + for i, layer_num in enumerate(self._layer_nums): + block, num_out_filters = self._make_layer( + in_filters[i], + self._num_filters[i], + layer_num, + stride=self._layer_strides[i], + ) + blocks.append(block) + self.blocks = 
nn.ModuleList(blocks) + self.up = nn.Sequential( + nn.ConvTranspose2d( + self._num_filters[0], + self._num_filters[2], + 2, + stride=2, + bias=False), + build_norm_layer(self._norm_cfg, self._num_filters[2])[1], + nn.ReLU()) + # heatmap prediction + hm_head = [] + for i in range(hm_head_layer - 1): + hm_head.append( + nn.Conv2d( + self._num_filters[-1] * 2, + 64, + kernel_size=3, + stride=1, + padding=1, + bias=True, + )) + hm_head.append(build_norm_layer(self._norm_cfg, 64)[1]) + hm_head.append(nn.ReLU()) + + hm_head.append( + nn.Conv2d( + 64, classes, kernel_size=3, stride=1, padding=1, bias=True)) + hm_head[-1].bias.data.fill_(init_bias) + self.hm_head = nn.Sequential(*hm_head) + + if self.corner: + self.corner_head = [] + for i in range(corner_head_layer - 1): + self.corner_head.append( + nn.Conv2d( + self._num_filters[-1] * 2, + 64, + kernel_size=3, + stride=1, + padding=1, + bias=True, + )) + self.corner_head.append( + build_norm_layer(self._norm_cfg, 64)[1]) + self.corner_head.append(nn.ReLU()) + + self.corner_head.append( + nn.Conv2d( + 64, 1, kernel_size=3, stride=1, padding=1, bias=True)) + self.corner_head[-1].bias.data.fill_(init_bias) + self.corner_head = nn.Sequential(*self.corner_head) + + def _make_layer(self, inplanes, planes, num_blocks, stride=1): + + if stride > 0: + block = [ + nn.ZeroPad2d(1), + nn.Conv2d(inplanes, planes, 3, stride=stride, bias=False), + build_norm_layer(self._norm_cfg, planes)[1], + nn.ReLU(), + ] + else: + block = [ + nn.ConvTranspose2d( + inplanes, planes, -stride, stride=-stride, bias=False), + build_norm_layer(self._norm_cfg, planes)[1], + nn.ReLU(), + ] + + for j in range(num_blocks): + block.append(nn.Conv2d(planes, planes, 3, padding=1, bias=False)) + block.append(build_norm_layer(self._norm_cfg, planes)[1], ) + block.append(nn.ReLU()) + + block.append(ChannelAttention(planes)) + block.append(SpatialAttention()) + block = nn.Sequential(*block) + + return block, planes + + # default init_weights for conv(msra) and norm in ConvModule + # def init_weights(self): + # for m in self.modules(): + # if isinstance(m, nn.Conv2d): + # xavier_init(m, distribution='uniform') + + def forward(self, x, example=None): + pass + + def get_multi_scale_feature(self, center_pos, feats): + """ + Args: + center_pos: center coor at the lowest scale feature map [B 500 2] + feats: multi scale BEV feature 3*[B C H W] + Returns: + neighbor_feat: [B 500 K C] + neighbor_pos: [B 500 K 2] + """ + kernel_size = self.cross_attention_kernel_size + batch, num_cls, H, W = feats[0].size() + + center_num = center_pos.shape[1] + + relative_pos_list = [] + neighbor_feat_list = [] + for i, k in enumerate(kernel_size): + neighbor_coords = torch.arange(-(k // 2), (k // 2) + 1) + neighbor_coords = torch.flatten( + torch.stack( + torch.meshgrid([neighbor_coords, neighbor_coords]), dim=0), + 1, + ) # [2, k] + neighbor_coords = (neighbor_coords.permute( + 1, + 0).contiguous().to(center_pos)) # relative coordinate [k, 2] + neighbor_coords = (center_pos[:, :, None, :] // (2**i) + + neighbor_coords[None, None, :, :] + ) # coordinates [B, 500, k, 2] + neighbor_coords = torch.clamp( + neighbor_coords, min=0, + max=H // (2**i) - 1) # prevent out of bound + feat_id = (neighbor_coords[:, :, :, 1] * (W // (2**i)) + + neighbor_coords[:, :, :, 0]) # pixel id [B, 500, k] + feat_id = feat_id.reshape(batch, -1) # pixel id [B, 500*k] + # selected_feat = torch.gather(feats[i].reshape(batch, num_cls,(H*W)//(4**i)).permute(0, 2, 1).contiguous(),1,feat_id) + selected_feat = ( + feats[i].reshape(batch, 
num_cls, (H * W) // (4**i)).permute( + 0, 2, 1).contiguous()[self.batch_id.repeat(1, k**2), + feat_id]) # B, 500*k, C + neighbor_feat_list.append( + selected_feat.reshape(batch, center_num, -1, + num_cls)) # B, 500, k, C + relative_pos_list.append(neighbor_coords * (2**i)) # B, 500, k, 2 + # relative_pos_list.append(F.pad(neighbor_coords*(2**i), (0,1), "constant", i)) # B, 500, k, 3 + + neighbor_pos = torch.cat(relative_pos_list, dim=2) # B, 500, K, 2/3 + neighbor_feats = torch.cat(neighbor_feat_list, dim=2) # B, 500, K, C + return neighbor_feats, neighbor_pos + + def get_multi_scale_feature_multiframe(self, center_pos, feats, timeframe): + """ + Args: + center_pos: center coor at the lowest scale feature map [B 500 2] + feats: multi scale BEV feature (3+k)*[B C H W] + timeframe: timeframe [B,k] + Returns: + neighbor_feat: [B 500 K C] + neighbor_pos: [B 500 K 2] + neighbor_time: [B 500 K 1] + """ + kernel_size = self.cross_attention_kernel_size + batch, num_cls, H, W = feats[0].size() + + center_num = center_pos.shape[1] + + relative_pos_list = [] + neighbor_feat_list = [] + timeframe_list = [] + for i, k in enumerate(kernel_size): + neighbor_coords = torch.arange(-(k // 2), (k // 2) + 1) + neighbor_coords = torch.flatten( + torch.stack( + torch.meshgrid([neighbor_coords, neighbor_coords]), dim=0), + 1, + ) # [2, k] + neighbor_coords = (neighbor_coords.permute( + 1, + 0).contiguous().to(center_pos)) # relative coordinate [k, 2] + neighbor_coords = (center_pos[:, :, None, :] // (2**i) + + neighbor_coords[None, None, :, :] + ) # coordinates [B, 500, k, 2] + neighbor_coords = torch.clamp( + neighbor_coords, min=0, + max=H // (2**i) - 1) # prevent out of bound + feat_id = (neighbor_coords[:, :, :, 1] * (W // (2**i)) + + neighbor_coords[:, :, :, 0]) # pixel id [B, 500, k] + feat_id = feat_id.reshape(batch, -1) # pixel id [B, 500*k] + selected_feat = ( + feats[i].reshape(batch, num_cls, (H * W) // (4**i)).permute( + 0, 2, 1).contiguous()[self.batch_id.repeat(1, k**2), + feat_id]) # B, 500*k, C + neighbor_feat_list.append( + selected_feat.reshape(batch, center_num, -1, + num_cls)) # B, 500, k, C + relative_pos_list.append(neighbor_coords * (2**i)) # B, 500, k, 2 + timeframe_list.append( + torch.full_like(neighbor_coords[:, :, :, 0:1], 0)) # B, 500, k + if i == 0: + # add previous frame feature + for frame_num in range(feats[-1].shape[1]): + selected_feat = (feats[-1][:, frame_num, :, :, :].reshape( + batch, num_cls, (H * W) // (4**i)).permute( + 0, 2, + 1).contiguous()[self.batch_id.repeat(1, k**2), + feat_id]) # B, 500*k, C + neighbor_feat_list.append( + selected_feat.reshape(batch, center_num, -1, num_cls)) + relative_pos_list.append(neighbor_coords * (2**i)) + time = timeframe[:, frame_num + 1].to(selected_feat) # B + timeframe_list.append( + time[:, None, None, None] * torch.full_like( + neighbor_coords[:, :, :, 0:1], 1)) # B, 500, k + + neighbor_pos = torch.cat(relative_pos_list, dim=2) # B, 500, K, 2/3 + neighbor_feats = torch.cat(neighbor_feat_list, dim=2) # B, 500, K, C + neighbor_time = torch.cat(timeframe_list, dim=2) # B, 500, K, 1 + + return neighbor_feats, neighbor_pos, neighbor_time + + +@MODELS.register_module() +class RPN_transformer_deformable(RPN_transformer_base): + + def __init__( + self, + layer_nums, # [2,2,2] + ds_num_filters, # [128,256,64] + num_input_features, + tasks=dict(), + transformer_config=None, + hm_head_layer=2, + corner_head_layer=2, + corner=False, + parametric_embedding=False, + assign_label_window_size=1, + classes=3, + use_gt_training=False, + 
norm_cfg=None, + logger=None, + init_bias=-2.19, + score_threshold=0.1, + obj_num=500, + train_cfg=None, + test_cfg=None, + **kwargs): + super(RPN_transformer_deformable, self).__init__( + layer_nums, + ds_num_filters, + num_input_features, + transformer_config, + hm_head_layer, + corner_head_layer, + corner, + assign_label_window_size, + classes, + use_gt_training, + norm_cfg, + logger, + init_bias, + score_threshold, + obj_num, + ) + self.train_cfg = train_cfg + self.test_cfg = test_cfg + self.tasks = tasks + num_classes = [len(t['class_names']) for t in tasks] + self.class_names = [t['class_names'] for t in tasks] + + self.transformer_layer = Deform_Transformer( + self._num_filters[-1] * 2, + depth=transformer_config.depth, + heads=transformer_config.heads, + dim_head=transformer_config.dim_head, + mlp_dim=transformer_config.MLP_dim, + dropout=transformer_config.DP_rate, + out_attention=transformer_config.out_att, + n_points=transformer_config.get('n_points', 9), + ) + self.pos_embedding_type = transformer_config.get( + 'pos_embedding_type', 'linear') + if self.pos_embedding_type == 'linear': + self.pos_embedding = nn.Linear(2, self._num_filters[-1] * 2) + else: + raise NotImplementedError() + self.parametric_embedding = parametric_embedding + if self.parametric_embedding: + self.query_embed = nn.Embedding(self.obj_num, + self._num_filters[-1] * 2) + nn.init.uniform_(self.query_embed.weight, -1.0, 1.0) + + print_log('Finish RPN_transformer_deformable Initialization', + 'current') + + def _sigmoid(self, x): + y = torch.clamp(x.sigmoid_(), min=1e-4, max=1 - 1e-4) + return y + + def forward(self, x, batch_data_samples): + + batch_gt_instance_3d = [] + for data_sample in batch_data_samples: + batch_gt_instance_3d.append(data_sample.gt_instances_3d) + + # FPN + x = self.blocks[0](x) + x_down = self.blocks[1](x) + x_up = torch.cat([self.blocks[2](x_down), self.up(x)], dim=1) + + # heatmap head + hm = self.hm_head(x_up) + + if self.corner and self.corner_head.training: + corner_hm = self.corner_head(x_up) + corner_hm = torch.sigmoid(corner_hm) + + # find top K center location + hm = torch.sigmoid(hm) + batch, num_cls, H, W = hm.size() + + scores, labels = torch.max( + hm.reshape(batch, num_cls, H * W), dim=1) # b,H*W + self.batch_id = torch.from_numpy(np.indices( + (batch, self.obj_num))[0]).to(labels) + + if self.training: + heatmaps, anno_boxes, gt_inds, gt_masks, corner_heatmaps, cat_labels = self.get_targets( # noqa: E501 + batch_gt_instance_3d) + batch_targets = dict( + ind=gt_inds, + mask=gt_masks, + hm=heatmaps, + anno_box=anno_boxes, + corners=corner_heatmaps, + cat=cat_labels) + inds = gt_inds[0][:, (self.window_size // 2)::self.window_size] + masks = gt_masks[0][:, (self.window_size // 2)::self.window_size] + batch_id_gt = torch.from_numpy( + np.indices((batch, inds.shape[1]))[0]).to(labels) + scores[batch_id_gt, inds] = scores[batch_id_gt, inds] + masks + order = scores.sort(1, descending=True)[1] + order = order[:, :self.obj_num] + scores[batch_id_gt, inds] = scores[batch_id_gt, inds] - masks + else: + order = scores.sort(1, descending=True)[1] + order = order[:, :self.obj_num] + batch_targets = None + + scores = torch.gather(scores, 1, order) + labels = torch.gather(labels, 1, order) + mask = scores > self.score_threshold + + ct_feat = (x_up.reshape(batch, -1, + H * W).transpose(2, + 1).contiguous()[self.batch_id, + order] + ) # B, 500, C + + # create position embedding for each center + y_coor = order // W + x_coor = order - y_coor * W + y_coor, x_coor = y_coor.to(ct_feat), 
x_coor.to(ct_feat) + y_coor, x_coor = y_coor / H, x_coor / W + pos_features = torch.stack([x_coor, y_coor], dim=2) + + if self.parametric_embedding: + ct_feat = self.query_embed.weight + ct_feat = ct_feat.unsqueeze(0).expand(batch, -1, -1) + + # run transformer + src = torch.cat( + ( + x_up.reshape(batch, -1, H * W).transpose(2, 1).contiguous(), + x.reshape(batch, -1, + (H * W) // 4).transpose(2, 1).contiguous(), + x_down.reshape(batch, -1, + (H * W) // 16).transpose(2, 1).contiguous(), + ), + dim=1, + ) # B ,sum(H*W), C + spatial_shapes = torch.as_tensor( + [(H, W), (H // 2, W // 2), (H // 4, W // 4)], + dtype=torch.long, + device=ct_feat.device, + ) + level_start_index = torch.cat(( + spatial_shapes.new_zeros((1, )), + spatial_shapes.prod(1).cumsum(0)[:-1], + )) + + transformer_out = self.transformer_layer( + ct_feat, + self.pos_embedding, + src, + spatial_shapes, + level_start_index, + center_pos=pos_features, + ) # (B,N,C) + + ct_feat = (transformer_out['ct_feat'].transpose(2, 1).contiguous() + ) # B, C, 500 + + out_dict = { + 'hm': hm, + 'scores': scores, + 'labels': labels, + 'order': order, + 'ct_feat': ct_feat, + 'mask': mask, + } + if 'out_attention' in transformer_out: + out_dict.update( + {'out_attention': transformer_out['out_attention']}) + if self.corner and self.corner_head.training: + out_dict.update({'corner_hm': corner_hm}) + + return out_dict, batch_targets + + def get_targets( + self, + batch_gt_instances_3d: List[InstanceData], + ) -> Tuple[List[Tensor]]: + """Generate targets. How each output is transformed: Each nested list + is transposed so that all same-index elements in each sub-list (1, ..., + N) become the new sub-lists. + + [ [a0, a1, a2, ... ], [b0, b1, b2, ... ], ... ] + ==> [ [a0, b0, ... ], [a1, b1, ... ], [a2, b2, ... ] ] + The new transposed nested list is converted into a list of N + tensors generated by concatenating tensors in the new sub-lists. + [ tensor0, tensor1, tensor2, ... ] + Args: + batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of + gt_instances. It usually includes ``bboxes_3d`` and\ + ``labels_3d`` attributes. + Returns: + Returns: + tuple[list[torch.Tensor]]: Tuple of target including + the following results in order. + - list[torch.Tensor]: Heatmap scores. + - list[torch.Tensor]: Ground truth boxes. + - list[torch.Tensor]: Indexes indicating the + position of the valid boxes. + - list[torch.Tensor]: Masks indicating which + boxes are valid. 
+ """ + heatmaps, anno_boxes, inds, masks, corner_heatmaps, cat_labels = multi_apply( + self.get_targets_single, batch_gt_instances_3d) + # Transpose heatmaps + heatmaps = list(map(list, zip(*heatmaps))) + heatmaps = [torch.stack(hms_) for hms_ in heatmaps] + # Transpose heatmaps + corner_heatmaps = list(map(list, zip(*corner_heatmaps))) + corner_heatmaps = [torch.stack(hms_) for hms_ in corner_heatmaps] + # Transpose anno_boxes + anno_boxes = list(map(list, zip(*anno_boxes))) + anno_boxes = [torch.stack(anno_boxes_) for anno_boxes_ in anno_boxes] + # Transpose inds + inds = list(map(list, zip(*inds))) + inds = [torch.stack(inds_) for inds_ in inds] + # Transpose inds + masks = list(map(list, zip(*masks))) + masks = [torch.stack(masks_) for masks_ in masks] + # Transpose cat_labels + cat_labels = list(map(list, zip(*cat_labels))) + cat_labels = [torch.stack(labels_) for labels_ in cat_labels] + return heatmaps, anno_boxes, inds, masks, corner_heatmaps, cat_labels + + def get_targets_single(self, + gt_instances_3d: InstanceData) -> Tuple[Tensor]: + """Generate training targets for a single sample. + Args: + gt_instances_3d (:obj:`InstanceData`): Gt_instances of + single data sample. It usually includes + ``bboxes_3d`` and ``labels_3d`` attributes. + Returns: + tuple[list[torch.Tensor]]: Tuple of target including + the following results in order. + - list[torch.Tensor]: Heatmap scores. + - list[torch.Tensor]: Ground truth boxes. + - list[torch.Tensor]: Indexes indicating the position + of the valid boxes. + - list[torch.Tensor]: Masks indicating which boxes + are valid. + """ + gt_labels_3d = gt_instances_3d.labels_3d + gt_bboxes_3d = gt_instances_3d.bboxes_3d + device = gt_labels_3d.device + gt_bboxes_3d = torch.cat( + (gt_bboxes_3d.gravity_center, gt_bboxes_3d.tensor[:, 3:]), + dim=1).to(device) + max_objs = self.train_cfg['max_objs'] * self.train_cfg['dense_reg'] + grid_size = torch.tensor(self.train_cfg['grid_size']) + pc_range = torch.tensor(self.train_cfg['point_cloud_range']) + voxel_size = torch.tensor(self.train_cfg['voxel_size']) + + feature_map_size = grid_size[:2] // self.train_cfg['out_size_factor'] + + # reorganize the gt_dict by tasks + task_masks = [] + flag = 0 + for class_name in self.class_names: + task_masks.append([ + torch.where(gt_labels_3d == class_name.index(i) + flag) + for i in class_name + ]) + flag += len(class_name) + + task_boxes = [] + task_classes = [] + flag2 = 0 + for idx, mask in enumerate(task_masks): + task_box = [] + task_class = [] + for m in mask: + task_box.append(gt_bboxes_3d[m]) + # 0 is background for each task, so we need to add 1 here. 
+ task_class.append(gt_labels_3d[m] + 1 - flag2) + task_boxes.append(torch.cat(task_box, axis=0).to(device)) + task_classes.append(torch.cat(task_class).long().to(device)) + flag2 += len(mask) + draw_gaussian = draw_heatmap_gaussian + heatmaps, anno_boxes, inds, masks, corner_heatmaps, cat_labels = [], [], [], [], [], [] + + for idx in range(len(self.tasks)): + heatmap = gt_bboxes_3d.new_zeros( + (len(self.class_names[idx]), feature_map_size[1], + feature_map_size[0])) + corner_heatmap = torch.zeros( + (1, feature_map_size[1], feature_map_size[0]), + dtype=torch.float32, + device=device) + + anno_box = gt_bboxes_3d.new_zeros((max_objs, 8), + dtype=torch.float32) + + ind = gt_labels_3d.new_zeros((max_objs), dtype=torch.int64) + mask = gt_bboxes_3d.new_zeros((max_objs), dtype=torch.uint8) + cat_label = gt_bboxes_3d.new_zeros((max_objs), dtype=torch.int64) + + num_objs = min(task_boxes[idx].shape[0], max_objs) + + for k in range(num_objs): + cls_id = task_classes[idx][k] - 1 + + width = task_boxes[idx][k][3] + length = task_boxes[idx][k][4] + width = width / voxel_size[0] / self.train_cfg[ + 'out_size_factor'] + length = length / voxel_size[1] / self.train_cfg[ + 'out_size_factor'] + + if width > 0 and length > 0: + radius = gaussian_radius( + (length, width), + min_overlap=self.train_cfg['gaussian_overlap']) + radius = max(self.train_cfg['min_radius'], int(radius)) + + # be really careful for the coordinate system of + # your box annotation. + x, y, z = task_boxes[idx][k][0], task_boxes[idx][k][ + 1], task_boxes[idx][k][2] + + coor_x = ( + x - pc_range[0] + ) / voxel_size[0] / self.train_cfg['out_size_factor'] + coor_y = ( + y - pc_range[1] + ) / voxel_size[1] / self.train_cfg['out_size_factor'] + + center = torch.tensor([coor_x, coor_y], + dtype=torch.float32, + device=device) + center_int = center.to(torch.int32) + + # throw out not in range objects to avoid out of array + # area when creating the heatmap + if not (0 <= center_int[0] < feature_map_size[0] + and 0 <= center_int[1] < feature_map_size[1]): + continue + + draw_gaussian(heatmap[cls_id], center_int, radius) + + radius = radius // 2 + # # draw four corner and center TODO: use torch + rot = task_boxes[idx][k][6] + corner_keypoints = center_to_corner_box2d( + center.unsqueeze(0).cpu().numpy(), + torch.tensor([[width, length]], + dtype=torch.float32).numpy(), + angles=rot, + origin=0.5) + corner_keypoints = torch.from_numpy(corner_keypoints).to( + center) + + draw_gaussian(corner_heatmap[0], center_int, radius) + draw_gaussian( + corner_heatmap[0], + (corner_keypoints[0, 0] + corner_keypoints[0, 1]) / 2, + radius) + draw_gaussian( + corner_heatmap[0], + (corner_keypoints[0, 2] + corner_keypoints[0, 3]) / 2, + radius) + draw_gaussian( + corner_heatmap[0], + (corner_keypoints[0, 0] + corner_keypoints[0, 3]) / 2, + radius) + draw_gaussian( + corner_heatmap[0], + (corner_keypoints[0, 1] + corner_keypoints[0, 2]) / 2, + radius) + + new_idx = k + x, y = center_int[0], center_int[1] + + assert (y * feature_map_size[0] + x < + feature_map_size[0] * feature_map_size[1]) + + ind[new_idx] = y * feature_map_size[0] + x + mask[new_idx] = 1 + cat_label[new_idx] = cls_id + # TODO: support other outdoor dataset + # vx, vy = task_boxes[idx][k][7:] + rot = task_boxes[idx][k][6] + box_dim = task_boxes[idx][k][3:6] + box_dim = box_dim.log() + anno_box[new_idx] = torch.cat([ + center - torch.tensor([x, y], device=device), + z.unsqueeze(0), box_dim, + torch.sin(rot).unsqueeze(0), + torch.cos(rot).unsqueeze(0) + ]) + + heatmaps.append(heatmap) + 
corner_heatmaps.append(corner_heatmap) + anno_boxes.append(anno_box) + masks.append(mask) + inds.append(ind) + cat_labels.append(cat_label) + return heatmaps, anno_boxes, inds, masks, corner_heatmaps, cat_labels + + +@MODELS.register_module() +class RPN_transformer_deformable_mtf(RPN_transformer_base): + + def __init__( + self, + layer_nums, # [2,2,2] + ds_num_filters, # [128,256,64] + num_input_features, # 256 + transformer_config=None, + hm_head_layer=2, + corner_head_layer=2, + corner=False, + parametric_embedding=False, + assign_label_window_size=1, + classes=3, + use_gt_training=False, + norm_cfg=None, + logger=None, + init_bias=-2.19, + score_threshold=0.1, + obj_num=500, + frame=1, + **kwargs): + super(RPN_transformer_deformable_mtf, self).__init__( + layer_nums, + ds_num_filters, + num_input_features, + transformer_config, + hm_head_layer, + corner_head_layer, + corner, + assign_label_window_size, + classes, + use_gt_training, + norm_cfg, + logger, + init_bias, + score_threshold, + obj_num, + ) + self.frame = frame + + self.out = nn.Sequential( + nn.Conv2d( + self._num_filters[0] * frame, + self._num_filters[0], + 3, + padding=1, + bias=False, + ), + build_norm_layer(self._norm_cfg, self._num_filters[0])[1], + nn.ReLU(), + ) + self.mtf_attention = SpatialAttention_mtf() + self.time_embedding = nn.Linear(1, self._num_filters[0]) + + self.transformer_layer = Deform_Transformer( + self._num_filters[-1] * 2, + depth=transformer_config.depth, + heads=transformer_config.heads, + levels=2 + self.frame, + dim_head=transformer_config.dim_head, + mlp_dim=transformer_config.MLP_dim, + dropout=transformer_config.DP_rate, + out_attention=transformer_config.out_att, + n_points=transformer_config.get('n_points', 9), + ) + self.pos_embedding_type = transformer_config.get( + 'pos_embedding_type', 'linear') + if self.pos_embedding_type == 'linear': + self.pos_embedding = nn.Linear(2, self._num_filters[-1] * 2) + else: + raise NotImplementedError() + self.parametric_embedding = parametric_embedding + if self.parametric_embedding: + self.query_embed = nn.Embedding(self.obj_num, + self._num_filters[-1] * 2) + nn.init.uniform_(self.query_embed.weight, -1.0, 1.0) + + print_log('Finish RPN_transformer_deformable Initialization', + 'current') + + def forward(self, x, example=None): + + # FPN + x = self.blocks[0](x) + x_down = self.blocks[1](x) + x_up = torch.cat([self.blocks[2](x_down), self.up(x)], dim=1) + + # take out the BEV feature on current frame + x = torch.split(x, self.frame) + x_up = torch.split(x_up, self.frame) + x_down = torch.split(x_down, self.frame) + x_prev = torch.stack([t[1:] for t in x_up], dim=0) # B,K,C,H,W + x = torch.stack([t[0] for t in x], dim=0) + x_down = torch.stack([t[0] for t in x_down], dim=0) + + x_up = torch.stack([t[0] for t in x_up], dim=0) # B,C,H,W + # use spatial attention in current frame on previous feature + x_prev_cat = self.mtf_attention( + x_up, + x_prev.reshape(x_up.shape[0], -1, x_up.shape[2], + x_up.shape[3])) # B,K*C,H,W + # time embedding + x_up_fuse = torch.cat((x_up, x_prev_cat), dim=1) + self.time_embedding( + example['times'][:, :, None].to(x_up)).reshape( + x_up.shape[0], -1, 1, 1) + # fuse mtf feature + x_up_fuse = self.out(x_up_fuse) + + # heatmap head + hm = self.hm_head(x_up_fuse) + + if self.corner and self.corner_head.training: + corner_hm = self.corner_head(x_up_fuse) + corner_hm = torch.sigmoid(corner_hm) + + # find top K center location + hm = torch.sigmoid(hm) + batch, num_cls, H, W = hm.size() + + scores, labels = torch.max( + 
hm.reshape(batch, num_cls, H * W), dim=1) # b,H*W + self.batch_id = torch.from_numpy(np.indices( + (batch, self.obj_num))[0]).to(labels) + + if self.use_gt_training and self.hm_head.training: + gt_inds = example['ind'][0][:, (self.window_size // + 2)::self.window_size] + gt_masks = example['mask'][0][:, (self.window_size // + 2)::self.window_size] + batch_id_gt = torch.from_numpy( + np.indices((batch, gt_inds.shape[1]))[0]).to(labels) + scores[batch_id_gt, + gt_inds] = scores[batch_id_gt, gt_inds] + gt_masks + order = scores.sort(1, descending=True)[1] + order = order[:, :self.obj_num] + scores[batch_id_gt, + gt_inds] = scores[batch_id_gt, gt_inds] - gt_masks + else: + order = scores.sort(1, descending=True)[1] + order = order[:, :self.obj_num] + + scores = torch.gather(scores, 1, order) + labels = torch.gather(labels, 1, order) + mask = scores > self.score_threshold + + ct_feat = (x_up.reshape(batch, -1, + H * W).transpose(2, + 1).contiguous()[self.batch_id, + order] + ) # B, 500, C + + # create position embedding for each center + y_coor = order // W + x_coor = order - y_coor * W + y_coor, x_coor = y_coor.to(ct_feat), x_coor.to(ct_feat) + y_coor, x_coor = y_coor / H, x_coor / W + pos_features = torch.stack([x_coor, y_coor], dim=2) + + if self.parametric_embedding: + ct_feat = self.query_embed.weight + ct_feat = ct_feat.unsqueeze(0).expand(batch, -1, -1) + + # run transformer + src_list = [ + x_up.reshape(batch, -1, H * W).transpose(2, 1).contiguous(), + x.reshape(batch, -1, (H * W) // 4).transpose(2, 1).contiguous(), + x_down.reshape(batch, -1, (H * W) // 16).transpose(2, + 1).contiguous(), + ] + for frame in range(x_prev.shape[1]): + src_list.append(x_prev[:, frame].reshape(batch, + -1, (H * W)).transpose( + 2, 1).contiguous()) + src = torch.cat(src_list, dim=1) # B ,sum(H*W), C + spatial_list = [(H, W), (H // 2, W // 2), (H // 4, W // 4)] + spatial_list += [(H, W) for frame in range(x_prev.shape[1])] + spatial_shapes = torch.as_tensor( + spatial_list, dtype=torch.long, device=ct_feat.device) + level_start_index = torch.cat(( + spatial_shapes.new_zeros((1, )), + spatial_shapes.prod(1).cumsum(0)[:-1], + )) + + transformer_out = self.transformer_layer( + ct_feat, + self.pos_embedding, + src, + spatial_shapes, + level_start_index, + center_pos=pos_features, + ) # (B,N,C) + + ct_feat = (transformer_out['ct_feat'].transpose(2, 1).contiguous() + ) # B, C, 500 + + out_dict = { + 'hm': hm, + 'scores': scores, + 'labels': labels, + 'order': order, + 'ct_feat': ct_feat, + 'mask': mask, + } + if 'out_attention' in transformer_out: + out_dict.update( + {'out_attention': transformer_out['out_attention']}) + if self.corner and self.corner_head.training: + out_dict.update({'corner_hm': corner_hm}) + + return out_dict diff --git a/projects/centerformer/centerformer/utils/__init__.py b/projects/centerformer/centerformer/utils/__init__.py new file mode 100644 index 0000000000..5509322aab --- /dev/null +++ b/projects/centerformer/centerformer/utils/__init__.py @@ -0,0 +1,11 @@ +from .attention import ChannelAttention, SpatialAttention +from .bbox_ops import boxes_iou3d_gpu_pcdet, rotate_nms_pcdet +from .multi_scale_deform_attn import MSDeformAttn +from .sparse_block import BasicBlockBias +from .transformer import Deform_Transformer + +__all__ = [ + 'ChannelAttention', 'SpatialAttention', 'BasicBlockBias', 'MSDeformAttn', + 'MSDeformAttn', 'Deform_Transformer', 'boxes_iou3d_gpu_pcdet', + 'rotate_nms_pcdet' +] diff --git a/projects/centerformer/centerformer/utils/attention.py 
b/projects/centerformer/centerformer/utils/attention.py new file mode 100644 index 0000000000..fa2b7d765a --- /dev/null +++ b/projects/centerformer/centerformer/utils/attention.py @@ -0,0 +1,42 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch +from mmengine.model import BaseModule +from torch import nn + + +class ChannelAttention(BaseModule): + + def __init__(self, in_planes, ratio=16): + super(ChannelAttention, self).__init__() + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.max_pool = nn.AdaptiveMaxPool2d(1) + + self.fc = nn.Sequential( + nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), + nn.ReLU(), + nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False), + ) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = self.fc(self.avg_pool(x)) + max_out = self.fc(self.max_pool(x)) + out = avg_out + max_out + return self.sigmoid(out) * x + + +class SpatialAttention(BaseModule): + + def __init__(self, kernel_size=7): + super(SpatialAttention, self).__init__() + + self.conv1 = nn.Conv2d( + 2, 1, kernel_size, padding=kernel_size // 2, bias=False) + self.sigmoid = nn.Sigmoid() + + def forward(self, x): + avg_out = torch.mean(x, dim=1, keepdim=True) + max_out, _ = torch.max(x, dim=1, keepdim=True) + y = torch.cat([avg_out, max_out], dim=1) + y = self.conv1(y) + return self.sigmoid(y) * x diff --git a/projects/centerformer/centerformer/utils/bbox_ops.py b/projects/centerformer/centerformer/utils/bbox_ops.py new file mode 100644 index 0000000000..fee864ecb9 --- /dev/null +++ b/projects/centerformer/centerformer/utils/bbox_ops.py @@ -0,0 +1,76 @@ +import torch + +from . import iou3d_nms_cuda + + +def rotate_nms_pcdet(boxes, + scores, + thresh, + pre_maxsize=None, + post_max_size=None): + """ + :param boxes: (N, 5) [x, y, z, l, w, h, theta] + :param scores: (N) + :param thresh: + :return: + """ + # transform back to pcdet's coordinate + # boxes = boxes[:, [0, 1, 2, 4, 3, 5, -1]] + # boxes[:, -1] = -boxes[:, -1] - np.pi / 2 + + order = scores.sort(0, descending=True)[1] + if pre_maxsize is not None: + order = order[:pre_maxsize] + + boxes = boxes[order].contiguous() + + keep = torch.LongTensor(boxes.size(0)) + + if len(boxes) == 0: + num_out = 0 + else: + num_out = iou3d_nms_cuda.nms_gpu(boxes, keep, thresh) + + selected = order[keep[:num_out].cuda()].contiguous() + + if post_max_size is not None: + selected = selected[:post_max_size] + + return selected + + +def boxes_iou3d_gpu_pcdet(boxes_a, boxes_b): + """ + Args: + boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading] + boxes_b: (N, 7) [x, y, z, dx, dy, dz, heading] + Returns: + ans_iou: (N, M) + """ + assert boxes_a.shape[1] == boxes_b.shape[1] == 7 + + # height overlap + boxes_a_height_max = (boxes_a[:, 2] + boxes_a[:, 5] / 2).view(-1, 1) + boxes_a_height_min = (boxes_a[:, 2] - boxes_a[:, 5] / 2).view(-1, 1) + boxes_b_height_max = (boxes_b[:, 2] + boxes_b[:, 5] / 2).view(1, -1) + boxes_b_height_min = (boxes_b[:, 2] - boxes_b[:, 5] / 2).view(1, -1) + + # bev overlap + overlaps_bev = torch.cuda.FloatTensor( + torch.Size((boxes_a.shape[0], boxes_b.shape[0]))).zero_() # (N, M) + iou3d_nms_cuda.boxes_overlap_bev_gpu(boxes_a.contiguous(), + boxes_b.contiguous(), overlaps_bev) + + max_of_min = torch.max(boxes_a_height_min, boxes_b_height_min) + min_of_max = torch.min(boxes_a_height_max, boxes_b_height_max) + overlaps_h = torch.clamp(min_of_max - max_of_min, min=0) + + # 3d iou + overlaps_3d = overlaps_bev * overlaps_h + + vol_a = (boxes_a[:, 3] * boxes_a[:, 4] * boxes_a[:, 5]).view(-1, 1) + vol_b = (boxes_b[:, 3] * 
boxes_b[:, 4] * boxes_b[:, 5]).view(1, -1) + + iou3d = overlaps_3d / torch.clamp(vol_a + vol_b - overlaps_3d, min=1e-6) + + return iou3d diff --git a/projects/centerformer/centerformer/utils/multi_scale_deform_attn.py b/projects/centerformer/centerformer/utils/multi_scale_deform_attn.py new file mode 100644 index 0000000000..ca2467ac88 --- /dev/null +++ b/projects/centerformer/centerformer/utils/multi_scale_deform_attn.py @@ -0,0 +1,203 @@ +import math + +import torch +import torch.nn.functional as F +from mmcv.utils import ext_loader +from torch import nn +from torch.autograd.function import Function, once_differentiable +from torch.nn.init import constant_, xavier_uniform_ + +ext_module = ext_loader.load_ext( + '_ext', ['ms_deform_attn_backward', 'ms_deform_attn_forward']) + + +class MultiScaleDeformableAttnFunction(Function): + + @staticmethod + def forward(ctx, value: torch.Tensor, value_spatial_shapes: torch.Tensor, + value_level_start_index: torch.Tensor, + sampling_locations: torch.Tensor, + attention_weights: torch.Tensor, + im2col_step: torch.Tensor) -> torch.Tensor: + """GPU/MLU version of multi-scale deformable attention. + + Args: + value (torch.Tensor): The value has shape + (bs, num_keys, mum_heads, embed_dims//num_heads) + value_spatial_shapes (torch.Tensor): Spatial shape of + each feature map, has shape (num_levels, 2), + last dimension 2 represent (h, w) + sampling_locations (torch.Tensor): The location of sampling points, + has shape + (bs ,num_queries, num_heads, num_levels, num_points, 2), + the last dimension 2 represent (x, y). + attention_weights (torch.Tensor): The weight of sampling points + used when calculate the attention, has shape + (bs ,num_queries, num_heads, num_levels, num_points), + im2col_step (torch.Tensor): The step used in image to column. + Returns: + torch.Tensor: has shape (bs, num_queries, embed_dims) + """ + + ctx.im2col_step = im2col_step + output = ext_module.ms_deform_attn_forward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + im2col_step=ctx.im2col_step) + ctx.save_for_backward(value, value_spatial_shapes, + value_level_start_index, sampling_locations, + attention_weights) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output: torch.Tensor) -> tuple: + """GPU/MLU version of backward function. + + Args: + grad_output (torch.Tensor): Gradient of output tensor of forward. + Returns: + tuple[Tensor]: Gradient of input tensors in forward. + """ + value, value_spatial_shapes, value_level_start_index,\ + sampling_locations, attention_weights = ctx.saved_tensors + grad_value = torch.zeros_like(value) + grad_sampling_loc = torch.zeros_like(sampling_locations) + grad_attn_weight = torch.zeros_like(attention_weights) + + ext_module.ms_deform_attn_backward( + value, + value_spatial_shapes, + value_level_start_index, + sampling_locations, + attention_weights, + grad_output.contiguous(), + grad_value, + grad_sampling_loc, + grad_attn_weight, + im2col_step=ctx.im2col_step) + + return grad_value, None, None, \ + grad_sampling_loc, grad_attn_weight, None + + +class MSDeformAttn(nn.Module): + + def __init__(self, + d_model=256, + d_head=64, + n_levels=4, + n_heads=8, + n_points=4, + out_sample_loc=False): + """Multi-Scale Deformable Attention Module. 
+ + :param d_model hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level # noqa: E501 + """ + super().__init__() + + self.im2col_step = 64 + + self.d_model = d_model + self.d_head = d_head + self.n_levels = n_levels + self.n_heads = n_heads + self.n_points = n_points + + self.out_sample_loc = out_sample_loc + + self.sampling_offsets = nn.Linear(d_model, + n_heads * n_levels * n_points * 2) + self.attention_weights = nn.Linear(d_model, + n_heads * n_levels * n_points) + self.value_proj = nn.Linear(d_model, d_head * n_heads) + self.output_proj = nn.Linear(d_head * n_heads, d_model) + + self._reset_parameters() + + def _reset_parameters(self): + constant_(self.sampling_offsets.weight.data, 0.) + thetas = torch.arange( + self.n_heads, dtype=torch.float32) * (2.0 * math.pi / self.n_heads) + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) + grid_init = (grid_init / + grid_init.abs().max(-1, keepdim=True)[0]).view( + self.n_heads, 1, 1, 2).repeat(1, self.n_levels, + self.n_points, 1) + for i in range(self.n_points): + grid_init[:, :, i, :] *= i + 1 + with torch.no_grad(): + self.sampling_offsets.bias = nn.Parameter(grid_init.view(-1)) + constant_(self.attention_weights.weight.data, 0.) + constant_(self.attention_weights.bias.data, 0.) + xavier_uniform_(self.value_proj.weight.data) + constant_(self.value_proj.bias.data, 0.) + xavier_uniform_(self.output_proj.weight.data) + constant_(self.output_proj.bias.data, 0.) + + def forward(self, + query, + reference_points, + input_flatten, + input_spatial_shapes, + input_level_start_index, + input_padding_mask=None): + """ + :param query: (N, Length_{query}, C) + :param reference_points + (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area # noqa: E501 + or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes # noqa: E501 + :param input_flatten : (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) + :param input_spatial_shapes: + (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] + :param input_level_start_index: (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] # noqa: E501 + :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements # noqa: E501 + :return output (N, Length_{query}, C) + """ + N, Len_q, _ = query.shape + N, Len_in, _ = input_flatten.shape + assert (input_spatial_shapes[:, 0] * + input_spatial_shapes[:, 1]).sum() == Len_in + + value = self.value_proj(input_flatten) + if input_padding_mask is not None: + value = value.masked_fill(input_padding_mask[..., None], float(0)) + value = value.view(N, Len_in, self.n_heads, self.d_head) + sampling_offsets = self.sampling_offsets(query).view( + N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) + attention_weights = self.attention_weights(query).view( + N, Len_q, self.n_heads, self.n_levels * self.n_points) + attention_weights = F.softmax(attention_weights, + -1).view(N, Len_q, self.n_heads, + self.n_levels, self.n_points) + # N, Len_q, n_heads, n_levels, n_points, 2 + if reference_points.shape[-1] == 2: + offset_normalizer = torch.stack( + [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], + -1).to(sampling_offsets) + + sampling_locations = reference_points[:, :, None, :, None, :] \ + + sampling_offsets / offset_normalizer[None, None, 
None, :, None, :] # noqa: E501 + elif reference_points.shape[-1] == 4: + sampling_locations = reference_points[:, :, None, :, None, :2] \ + + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 # noqa: E501 + else: + raise ValueError( + 'Last dim of reference_points must be 2 or 4, but get {} instead.' # noqa: E501 + .format(reference_points.shape[-1])) + output = MultiScaleDeformableAttnFunction.apply( + value, input_spatial_shapes, input_level_start_index, + sampling_locations, attention_weights, self.im2col_step) + output = self.output_proj(output) + if self.out_sample_loc: + return output, torch.cat( + (sampling_locations, attention_weights[:, :, :, :, :, None]), + dim=-1) + else: + return output, None diff --git a/projects/centerformer/centerformer/utils/sparse_block.py b/projects/centerformer/centerformer/utils/sparse_block.py new file mode 100644 index 0000000000..6d9221aeab --- /dev/null +++ b/projects/centerformer/centerformer/utils/sparse_block.py @@ -0,0 +1,88 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer +from mmengine.model import BaseModule +from torch import nn + + +class BasicBlockBias(BaseModule): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None, + init_cfg=None): + super(BasicBlockBias, self).__init__(init_cfg) + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=True) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, planes, planes, 3, padding=1, bias=True) + self.add_module(self.norm2_name, norm2) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.with_cp = with_cp + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out diff --git a/projects/centerformer/centerformer/utils/transformer.py b/projects/centerformer/centerformer/utils/transformer.py new file mode 100644 index 0000000000..f7b06e186c --- /dev/null +++ b/projects/centerformer/centerformer/utils/transformer.py @@ -0,0 +1,421 @@ +import math + +import torch +from einops import rearrange +from torch import einsum, nn +from torch.nn import functional as F + +from .multi_scale_deform_attn import MSDeformAttn + + +class MLP(nn.Module): + """Very simple multi-layer perceptron (also called FFN)""" + 
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers): + super().__init__() + self.num_layers = num_layers + h = [hidden_dim] * (num_layers - 1) + self.layers = nn.ModuleList( + nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) + + def forward(self, x): + for i, layer in enumerate(self.layers): + x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) + return x + + +class GELU(nn.Module): + + def forward(self, x): + return 0.5 * x * (1 + torch.tanh( + math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) + + +# transformer layer +class PreNorm(nn.Module): + + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x, **kwargs): + return self.fn(self.norm(x), **kwargs) + + +class PreNorm_CA(nn.Module): + + def __init__(self, dim, fn): + super().__init__() + self.norm = nn.LayerNorm(dim) + self.fn = fn + + def forward(self, x, y, **kwargs): + return self.fn(self.norm(x), self.norm(y), **kwargs) + + +class FeedForward(nn.Module): + + def __init__(self, dim, hidden_dim, dropout=0.0): + super().__init__() + self.net = nn.Sequential( + nn.Linear(dim, hidden_dim), + GELU(), + nn.Dropout(dropout), + nn.Linear(hidden_dim, dim), + nn.Dropout(dropout), + ) + + def forward(self, x): + return self.net(x) + + +class Attention(nn.Module): + + def __init__(self, + dim, + heads=8, + dim_head=64, + dropout=0.0, + out_attention=False): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head**-0.5 + self.out_attention = out_attention + + self.attend = nn.Softmax(dim=-1) + self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) + + self.to_out = ( + nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + if project_out else nn.Identity()) + + def forward(self, x): + b, n, _, h = *x.shape, self.heads + qkv = self.to_qkv(x).chunk(3, dim=-1) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + + attn = self.attend(dots) + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, 'b h n d -> b n (h d)') + + if self.out_attention: + return self.to_out(out), attn + else: + return self.to_out(out) + + +class Cross_attention(nn.Module): + + def __init__(self, + dim, + heads=8, + dim_head=64, + dropout=0.0, + out_attention=False): + super().__init__() + inner_dim = dim_head * heads + project_out = not (heads == 1 and dim_head == dim) + + self.heads = heads + self.scale = dim_head**-0.5 + self.out_attention = out_attention + + self.attend = nn.Softmax(dim=-1) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_q = nn.Linear(dim, inner_dim, bias=False) + + self.to_out = ( + nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) + if project_out else nn.Identity()) + + def forward(self, x, y): + b, n, m, _, h = *y.shape, self.heads + q = self.to_q(x) + kv = self.to_kv(y).chunk(2, dim=-1) + q = rearrange(q, 'b n (h d) -> (b n) h 1 d', h=h) + k, v = map(lambda t: rearrange(t, 'b n m (h d) -> (b n) h m d', h=h), + kv) + + dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale + + attn = self.attend(dots) + + out = einsum('b h i j, b h j d -> b h i d', attn, v) + out = rearrange(out, '(b n) h 1 d -> b n (h d)', b=b) + + if self.out_attention: + return self.to_out(out), rearrange( + attn, '(b n) h i j -> b n h (i j)', b=b) + else: + return self.to_out(out) + + +class 
DeformableTransformerCrossAttention(nn.Module): + + def __init__( + self, + d_model=256, + d_head=64, + dropout=0.3, + n_levels=3, + n_heads=6, + n_points=9, + out_sample_loc=False, + ): + super().__init__() + + # cross attention + self.cross_attn = MSDeformAttn( + d_model, + d_head, + n_levels, + n_heads, + n_points, + out_sample_loc=out_sample_loc) + self.dropout = nn.Dropout(dropout) + self.out_sample_loc = out_sample_loc + + @staticmethod + def with_pos_embed(tensor, pos): + return tensor if pos is None else tensor + pos + + def forward( + self, + tgt, + src, + query_pos=None, + reference_points=None, + src_spatial_shapes=None, + level_start_index=None, + src_padding_mask=None, + ): + # cross attention + tgt2, sampling_locations = self.cross_attn( + self.with_pos_embed(tgt, query_pos), + reference_points, + src, + src_spatial_shapes, + level_start_index, + src_padding_mask, + ) + tgt = self.dropout(tgt2) + + if self.out_sample_loc: + return tgt, sampling_locations + else: + return tgt + + +class Transformer(nn.Module): + + def __init__( + self, + dim, + depth=2, + heads=4, + dim_head=64, + mlp_dim=256, + dropout=0.0, + out_attention=False, + ): + super().__init__() + self.out_attention = out_attention + self.layers = nn.ModuleList([]) + self.depth = depth + + for _ in range(depth): + self.layers.append( + nn.ModuleList([ + PreNorm( + dim, + Attention( + dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + out_attention=self.out_attention, + ), + ), + PreNorm_CA( + dim, + Cross_attention( + dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + out_attention=self.out_attention, + ), + ), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)), + ])) + + def forward(self, + x, + pos_embedding=None, + center_pos=None, + y=None, + neighbor_pos=None): + if self.out_attention: + out_cross_attention_list = [] + if center_pos is not None and pos_embedding is not None: + center_pos_embedding = pos_embedding(center_pos) + if neighbor_pos is not None and pos_embedding is not None: + neighbor_pos_embedding = pos_embedding(neighbor_pos) + for i, (self_attn, cross_attn, ff) in enumerate(self.layers): + if self.out_attention: + if pos_embedding is not None: + x_att, self_att = self_attn(x + center_pos_embedding) + x = x_att + x + x_att, cross_att = cross_attn(x + center_pos_embedding, + y + neighbor_pos_embedding) + else: + x_att, self_att = self_attn(x) + x = x_att + x + x_att, cross_att = cross_attn(x, y) + out_cross_attention_list.append(cross_att) + else: + if pos_embedding is not None: + x_att = self_attn(x + center_pos_embedding) + x = x_att + x + x_att = cross_attn(x + center_pos_embedding, + y + neighbor_pos_embedding) + else: + x_att = self_attn(x) + x = x_att + x + x_att = cross_attn(x, y) + + x = x_att + x + x = ff(x) + x + + out_dict = {'ct_feat': x} + if self.out_attention: + out_dict.update({ + 'out_attention': + torch.stack(out_cross_attention_list, dim=2) + }) + return out_dict + + +class Deform_Transformer(nn.Module): + + def __init__( + self, + dim, + levels=3, + depth=2, + heads=4, + dim_head=32, + mlp_dim=256, + dropout=0.0, + out_attention=False, + n_points=9, + ): + super().__init__() + self.out_attention = out_attention + self.layers = nn.ModuleList([]) + self.depth = depth + self.levels = levels + self.n_points = n_points + + for _ in range(depth): + self.layers.append( + nn.ModuleList([ + PreNorm( + dim, + Attention( + dim, + heads=heads, + dim_head=dim_head, + dropout=dropout, + out_attention=self.out_attention, + ), + ), + PreNorm_CA( + dim, + 
DeformableTransformerCrossAttention( + dim, + dim_head, + n_levels=levels, + n_heads=heads, + dropout=dropout, + n_points=n_points, + out_sample_loc=self.out_attention, + ), + ), + PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)), + ])) + + def forward(self, x, pos_embedding, src, src_spatial_shapes, + level_start_index, center_pos): + if self.out_attention: + out_cross_attention_list = [] + if pos_embedding is not None: + center_pos_embedding = pos_embedding(center_pos) + reference_points = center_pos[:, :, + None, :].repeat(1, 1, self.levels, 1) + for i, (self_attn, cross_attn, ff) in enumerate(self.layers): + if self.out_attention: + if center_pos_embedding is not None: + x_att, self_att = self_attn(x + center_pos_embedding) + x = x_att + x + x_att, cross_att = cross_attn( + x, + src, + query_pos=center_pos_embedding, + reference_points=reference_points, + src_spatial_shapes=src_spatial_shapes, + level_start_index=level_start_index, + ) + else: + x_att, self_att = self_attn(x) + x = x_att + x + x_att, cross_att = cross_attn( + x, + src, + query_pos=None, + reference_points=reference_points, + src_spatial_shapes=src_spatial_shapes, + level_start_index=level_start_index, + ) + out_cross_attention_list.append(cross_att) + else: + if center_pos_embedding is not None: + x_att = self_attn(x + center_pos_embedding) + x = x_att + x + x_att = cross_attn( + x, + src, + query_pos=center_pos_embedding, + reference_points=reference_points, + src_spatial_shapes=src_spatial_shapes, + level_start_index=level_start_index, + ) + else: + x_att = self_attn(x) + x = x_att + x + x_att = cross_attn( + x, + src, + query_pos=None, + reference_points=reference_points, + src_spatial_shapes=src_spatial_shapes, + level_start_index=level_start_index, + ) + + x = x_att + x + x = ff(x) + x + + out_dict = {'ct_feat': x} + if self.out_attention: + out_dict.update({ + 'out_attention': + torch.stack(out_cross_attention_list, dim=2) + }) + return out_dict diff --git a/projects/centerformer/configs/centerform_voxel01.py b/projects/centerformer/configs/centerform_voxel01.py new file mode 100644 index 0000000000..fad23abeae --- /dev/null +++ b/projects/centerformer/configs/centerform_voxel01.py @@ -0,0 +1,309 @@ +_base_ = ['mmdet3d::_base_/default_runtime.py'] + +# model settings +# Voxel size for voxel encoder +# Usually voxel size is changed consistently with the point cloud range +# If point cloud range is modified, do remember to change all related +# keys in the config. 
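+# For reference: with the point_cloud_range and voxel_size below, the BEV grid
+# is (75.2 - (-75.2)) / 0.1 = 1504 cells per side and (4 - (-2)) / 0.15 = 40
+# cells in height, which matches grid_size=[1504, 1504, 40] in train_cfg and
+# the middle encoder's sparse_shape=[41, 1504, 1504] (height + 1).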
+voxel_size = [0.1, 0.1, 0.15] +point_cloud_range = [-75.2, -75.2, -2, 75.2, 75.2, 4] +class_names = ['Car', 'Pedestrian', 'Cyclist'] +tasks = [dict(num_class=3, class_names=['car', 'pedestrian', 'cyclist'])] +metainfo = dict(classes=class_names) +input_modality = dict(use_lidar=True, use_camera=False) +file_client_args = dict(backend='disk') + +model = dict( + type='CenterFormer', + data_preprocessor=dict( + type='Det3DDataPreprocessor', + voxel=True, + voxel_type='dynamic', + voxel_layer=dict( + max_num_points=-1, + point_cloud_range=point_cloud_range, + voxel_size=voxel_size, + max_voxels=(-1, -1))), + voxel_encoder=dict( + type='DynamicSimpleVFE', + point_cloud_range=point_cloud_range, + voxel_size=voxel_size), + middle_encoder=dict( + type='SparseEncoder', + in_channels=5, + sparse_shape=[41, 1504, 1504], + order=('conv', 'norm', 'act'), + norm_cfg=dict(type='naiveSyncBN1d', eps=0.001, momentum=0.01), + encoder_channels=((16, 16, 32), (32, 32, 64), (64, 64, 128), (128, + 128)), + encoder_paddings=((1, 1, 1), (1, 1, 1), (1, 1, [0, 1, 1]), (1, 1)), + block_type='basicblock'), + backbone=dict( + type='RPN_transformer_deformable', + layer_nums=[5, 5, 1], + ds_num_filters=[256, 256, 128], + num_input_features=256, + tasks=tasks, + use_gt_training=True, + corner=True, + assign_label_window_size=1, + obj_num=500, + norm_cfg=dict(type='SyncBN', eps=1e-3, momentum=0.01), + transformer_config=dict( + depth=2, + heads=6, + dim_head=64, + MLP_dim=256, + DP_rate=0.3, + out_att=False, + n_points=15, + ), + ), + bbox_head=dict( + type='CenterHeadIoU_1d', + in_channels=256, + tasks=tasks, + dataset='waymo', + weight=2, + corner_loss=True, + iou_loss=True, + assign_label_window_size=1, + norm_cfg=dict(type='SyncBN', eps=1e-3, momentum=0.01), + code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], + common_heads={ + 'reg': (2, 2), + 'height': (1, 2), + 'dim': (3, 2), + 'rot': (2, 2), + 'iou': (1, 2) + }, # (output_channel, num_conv) + ), + train_cfg=dict( + grid_size=[1504, 1504, 40], + voxel_size=voxel_size, + out_size_factor=4, + dense_reg=1, + gaussian_overlap=0.1, + point_cloud_range=point_cloud_range, + max_objs=500, + min_radius=2, + code_weights=[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]), + test_cfg=dict( + post_center_limit_range=[-80, -80, -10.0, 80, 80, 10.0], + nms=dict( + use_rotate_nms=False, + use_multi_class_nms=True, + nms_pre_max_size=[1600, 1600, 800], + nms_post_max_size=[200, 200, 100], + nms_iou_threshold=[0.8, 0.55, 0.55], + ), + score_threshold=0.1, + pc_range=[-75.2, -75.2], + out_size_factor=4, + voxel_size=[0.1, 0.1], + obj_num=1000, + )) + +data_root = 'data/waymo/kitti_format/' +db_sampler = dict( + data_root=data_root, + info_path=data_root + 'waymo_dbinfos_train.pkl', + rate=1.0, + prepare=dict( + filter_by_difficulty=[-1], + filter_by_min_points=dict(Car=5, Pedestrian=5, Cyclist=5)), + classes=class_names, + sample_groups=dict(Car=15, Pedestrian=10, Cyclist=10), + points_loader=dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=6, + use_dim=[0, 1, 2, 3, 4])) + +train_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=6, + use_dim=5, + norm_intensity=True), + # dict( + # type='LoadPointsFromMultiSweeps', + # sweeps_num=9, + # load_dim=6, + # use_dim=[0, 1, 2, 3, 4], + # pad_empty_sweeps=True, + # remove_close=True), + dict(type='LoadAnnotations3D', with_bbox_3d=True, with_label_3d=True), + dict(type='ObjectSample', db_sampler=db_sampler), + dict( + type='GlobalRotScaleTrans', + rot_range=[-0.78539816, 0.78539816], + 
scale_ratio_range=[0.95, 1.05], + translation_std=[0.5, 0.5, 0]), + dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectRangeFilter', point_cloud_range=point_cloud_range), + dict(type='ObjectNameFilter', classes=class_names), + dict(type='PointShuffle'), + dict( + type='Pack3DDetInputs', + keys=['points', 'gt_bboxes_3d', 'gt_labels_3d']) +] + +test_pipeline = [ + dict( + type='LoadPointsFromFile', + coord_type='LIDAR', + load_dim=6, + use_dim=5, + norm_intensity=True, + file_client_args=file_client_args), + dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), + # dict( + # type='MultiScaleFlipAug3D', + # img_scale=(1333, 800), + # pts_scale_ratio=1, + # flip=False, + # transforms=[ + # dict( + # type='GlobalRotScaleTrans', + # rot_range=[0, 0], + # scale_ratio_range=[1., 1.], + # translation_std=[0, 0, 0]), + # dict(type='RandomFlip3D'), + # dict( + # type='PointsRangeFilter', point_cloud_range=point_cloud_range) + # ]), + dict(type='Pack3DDetInputs', keys=['points']) +] + +dataset_type = 'WaymoDataset' +train_dataloader = dict( + batch_size=4, + num_workers=4, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='waymo_infos_train.pkl', + data_prefix=dict(pts='training/velodyne', sweeps='training/velodyne'), + pipeline=train_pipeline, + modality=input_modality, + test_mode=False, + metainfo=metainfo, + # we use box_type_3d='LiDAR' in kitti and nuscenes dataset + # and box_type_3d='Depth' in sunrgbd and scannet dataset. + box_type_3d='LiDAR', + # load one frame every five frames + load_interval=5, + file_client_args=file_client_args)) +val_dataloader = dict( + batch_size=1, + num_workers=1, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + data_prefix=dict(pts='training/velodyne', sweeps='training/velodyne'), + ann_file='waymo_infos_val.pkl', + pipeline=test_pipeline, + modality=input_modality, + test_mode=True, + metainfo=metainfo, + box_type_3d='LiDAR', + file_client_args=file_client_args)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type='WaymoMetric', + ann_file='./data/waymo/kitti_format/waymo_infos_val.pkl', + waymo_bin_file='./data/waymo/waymo_format/gt.bin', + data_root='./data/waymo/waymo_format', + file_client_args=file_client_args, + convert_kitti_format=False, + idx2metainfo='./data/waymo/waymo_format/idx2metainfo.pkl') +test_evaluator = val_evaluator + +vis_backends = [dict(type='LocalVisBackend')] +visualizer = dict( + type='Det3DLocalVisualizer', vis_backends=vis_backends, name='visualizer') + +# For waymo dataset, we usually evaluate the model at the end of training. +# Since the models are trained by 24 epochs by default, we set evaluation +# interval to be 20. Please change the interval accordingly if you do not +# use a default schedule. 
+# optimizer +lr = 3e-4 +# This schedule is mainly used by models on nuScenes dataset +# max_norm=10 is better for SECOND +optim_wrapper = dict( + type='OptimWrapper', + optimizer=dict(type='AdamW', lr=lr, weight_decay=0.01, betas=(0.9, 0.99)), + clip_grad=dict(max_norm=35, norm_type=2)) +# learning rate +param_scheduler = [ + # learning rate scheduler + # During the first 8 epochs, learning rate increases from 0 to lr * 10 + # during the next 12 epochs, learning rate decreases from lr * 10 to + # lr * 1e-4 + dict( + type='CosineAnnealingLR', + T_max=8, + eta_min=lr * 10, + begin=0, + end=8, + by_epoch=True, + convert_to_iter_based=True), + dict( + type='CosineAnnealingLR', + T_max=12, + eta_min=lr * 1e-4, + begin=8, + end=20, + by_epoch=True, + convert_to_iter_based=True), + # momentum scheduler + # During the first 8 epochs, momentum increases from 0 to 0.85 / 0.95 + # during the next 12 epochs, momentum increases from 0.85 / 0.95 to 1 + dict( + type='CosineAnnealingMomentum', + T_max=8, + eta_min=0.85 / 0.95, + begin=0, + end=8, + by_epoch=True, + convert_to_iter_based=True), + dict( + type='CosineAnnealingMomentum', + T_max=12, + eta_min=1, + begin=8, + end=20, + by_epoch=True, + convert_to_iter_based=True) +] + +# runtime settings +train_cfg = dict(by_epoch=True, max_epochs=20, val_interval=20) +val_cfg = dict() +test_cfg = dict() + +# Default setting for scaling LR automatically +# - `enable` means enable scaling LR automatically +# or not by default. +# - `base_batch_size` = (4 GPUs) x (4 samples per GPU). +auto_scale_lr = dict(enable=False, base_batch_size=16) + +default_hooks = dict( + logger=dict( + type='LoggerHook', + interval=50, + ), + checkpoint=dict(type='CheckpointHook', interval=5)) +custom_hooks = [dict(type='DisableObjectSampleHook', disable_after_epoch=15)] +# load_from = 'checkpoints/init_centerformer_converted.pth' +# load_from = 'checkpoints/centerformer_our_init.pth' +load_from = None From 534224c9fd7fb9c930d05ec14862dacfbc03df63 Mon Sep 17 00:00:00 2001 From: JingweiZhang12 Date: Mon, 26 Dec 2022 13:02:42 +0800 Subject: [PATCH 02/18] add readme and disable_tf32_switch --- projects/centerformer/README.md | 97 +++++++++++++++++++ .../centerformer/rpn_transformer.py | 2 +- ...atten_4xb4-cyclic-20e_waymoD5-3d-class.py} | 34 +++---- tools/train.py | 10 ++ 4 files changed, 123 insertions(+), 20 deletions(-) create mode 100644 projects/centerformer/README.md rename projects/centerformer/configs/{centerform_voxel01.py => centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py} (92%) diff --git a/projects/centerformer/README.md b/projects/centerformer/README.md new file mode 100644 index 0000000000..fc862da7fc --- /dev/null +++ b/projects/centerformer/README.md @@ -0,0 +1,97 @@ +# CenterFormer: Center-based Transformer for 3D Object Detection + +> [CenterFormer: Center-based Transformer for 3D Object Detection](https://arxiv.org/abs/2209.05588) + + + +## Abstract + +Query-based transformer has shown great potential in con- +structing long-range attention in many image-domain tasks, but has +rarely been considered in LiDAR-based 3D object detection due to the +overwhelming size of the point cloud data. In this paper, we propose +CenterFormer, a center-based transformer network for 3D object de- +tection. CenterFormer first uses a center heatmap to select center candi- +dates on top of a standard voxel-based point cloud encoder. It then uses +the feature of the center candidate as the query embedding in the trans- +former. 
To further aggregate features from multiple frames, we design
+an approach to fuse features through cross-attention. Lastly, regression
+heads are added to predict the bounding box on the output center feature
+representation. Our design reduces the convergence difficulty and
+computational complexity of the transformer structure. The results show
+significant improvements over the strong baseline of anchor-free object
+detection networks. CenterFormer achieves state-of-the-art performance for
+a single model on the Waymo Open Dataset, with 73.7% mAPH on the validation
+set and 75.6% mAPH on the test set, significantly outperforming all
+previously published CNN and transformer-based methods. Our code is
+publicly available at https://github.com/TuSimple/centerformer
+
+
+ +
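+Below is a minimal, illustrative sketch of the center-query step described in
+the abstract (the helper name and tensor shapes are assumptions for
+illustration, not the exact module API): heatmap peaks are taken as object
+candidates and their BEV features become the transformer query embeddings,
+mirroring the top-K selection implemented in `rpn_transformer.py`.
+
+```python
+import torch
+
+
+def select_center_queries(heatmap, bev_feat, num_queries=500):
+    """Pick the top-K heatmap peaks and gather their BEV features as queries.
+
+    Illustrative only; names and shapes are assumed, not the project's API.
+    """
+    # heatmap: (B, num_classes, H, W); bev_feat: (B, C, H, W)
+    batch, _, H, W = heatmap.shape
+    # best class score per BEV location after sigmoid
+    scores, labels = torch.sigmoid(heatmap).reshape(batch, -1, H * W).max(dim=1)
+    # indices of the K highest-scoring center candidates
+    order = scores.argsort(dim=1, descending=True)[:, :num_queries]
+    batch_id = torch.arange(batch, device=order.device).unsqueeze(1).expand_as(order)
+    # gather the BEV feature at every selected center -> (B, K, C)
+    queries = bev_feat.reshape(batch, -1, H * W).transpose(1, 2)[batch_id, order]
+    return queries, scores.gather(1, order), labels.gather(1, order), order
+```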
+
+## Introduction
+
+We implement CenterFormer and provide the results and checkpoints on the Waymo dataset.
+
+We follow the style below to name config files. Contributors are advised to follow the same style.
+`{xxx}` is a required field and `[yyy]` is optional.
+
+`{model}`: model type like `centerpoint`.
+
+`{model setting}`: voxel size and voxel type like `01voxel`, `02pillar`.
+
+`{backbone}`: backbone type like `second`.
+
+`{neck}`: neck type like `secfpn`.
+
+`[batch_per_gpu x gpu]`: GPUs and samples per GPU, `4xb4` is used by default.
+
+`{schedule}`: training schedule, options are 1x, 2x, 20e, etc. 1x and 2x mean 12 epochs and 24 epochs respectively. 20e is adopted in cascade models, which denotes 20 epochs. For 1x/2x, the initial learning rate decays by a factor of 10 at the 8/16th and 11/22nd epochs. For 20e, the initial learning rate decays by a factor of 10 at the 16th and 19th epochs.
+
+`{dataset}`: dataset like nus-3d, kitti-3d, lyft-3d, scannet-3d, sunrgbd-3d. We also indicate the number of classes we are using if there exist multiple settings, e.g., kitti-3d-3class and kitti-3d-car mean training on the KITTI dataset with 3 classes and a single class, respectively.
+
+## Usage
+
+
+
+### Training commands
+
+In MMDetection3D's root directory, run the following command to train the model:
+
+```bash
+python tools/train.py projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py
+```
+
+For multi-gpu training, run:
+
+```bash
+python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=${NUM_GPUS} --master_port=29506 --master_addr="127.0.0.1" tools/train.py projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py
+```
+
+### Testing commands
+
+In MMDetection3D's root directory, run the following command to test the model:
+
+```bash
+python tools/test.py projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py ${CHECKPOINT_PATH}
+```
+
+## Results and models
+
+### CenterFormer
+
+| Backbone | Voxel type (voxel size) | Multi-Class NMS | Mem (GB) | Inf time (fps) | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download |
+| :-------------------------------------------------------------------------------------------------------------: | :---------------------: | :-------------: | :------: | :------------: | :----: | :-----: | :----: | :---------: | :----------------------: |
+| [SECFPN_WithAtten](./configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py) | voxel (0.1) | ✓ | 5.2 | | | | | | [model](<>) \| [log](<>) |
+
+## Citation
+
+```latex
+@InProceedings{Zhou_centerformer,
+title = {CenterFormer: Center-based Transformer for 3D Object Detection},
+author = {Zhou, Zixiang and Zhao, Xiangchen and Wang, Yu and Wang, Panqu and Foroosh, Hassan},
+booktitle = {ECCV},
+year = {2022}
+}
+```
diff --git a/projects/centerformer/centerformer/rpn_transformer.py b/projects/centerformer/centerformer/rpn_transformer.py
index 7379e47f96..89f7720bbb 100644
--- a/projects/centerformer/centerformer/rpn_transformer.py
+++ b/projects/centerformer/centerformer/rpn_transformer.py
@@ -1,5 +1,5 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, List, Tuple import numpy as np import torch diff --git a/projects/centerformer/configs/centerform_voxel01.py b/projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py similarity index 92% rename from projects/centerformer/configs/centerform_voxel01.py rename to projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py index fad23abeae..f71230f2f9 100644 --- a/projects/centerformer/configs/centerform_voxel01.py +++ b/projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py @@ -158,22 +158,21 @@ use_dim=5, norm_intensity=True, file_client_args=file_client_args), - dict(type='PointsRangeFilter', point_cloud_range=point_cloud_range), - # dict( - # type='MultiScaleFlipAug3D', - # img_scale=(1333, 800), - # pts_scale_ratio=1, - # flip=False, - # transforms=[ - # dict( - # type='GlobalRotScaleTrans', - # rot_range=[0, 0], - # scale_ratio_range=[1., 1.], - # translation_std=[0, 0, 0]), - # dict(type='RandomFlip3D'), - # dict( - # type='PointsRangeFilter', point_cloud_range=point_cloud_range) - # ]), + dict( + type='MultiScaleFlipAug3D', + img_scale=(1333, 800), + pts_scale_ratio=1, + flip=False, + transforms=[ + dict( + type='GlobalRotScaleTrans', + rot_range=[0, 0], + scale_ratio_range=[1., 1.], + translation_std=[0, 0, 0]), + dict(type='RandomFlip3D'), + dict( + type='PointsRangeFilter', point_cloud_range=point_cloud_range) + ]), dict(type='Pack3DDetInputs', keys=['points']) ] @@ -304,6 +303,3 @@ ), checkpoint=dict(type='CheckpointHook', interval=5)) custom_hooks = [dict(type='DisableObjectSampleHook', disable_after_epoch=15)] -# load_from = 'checkpoints/init_centerformer_converted.pth' -# load_from = 'checkpoints/centerformer_our_init.pth' -load_from = None diff --git a/tools/train.py b/tools/train.py index 7c65903c24..fe8ccbfa53 100644 --- a/tools/train.py +++ b/tools/train.py @@ -21,6 +21,11 @@ def parse_args(): action='store_true', default=False, help='enable automatic-mixed-precision training') + parser.add_argument( + '--disable-tf32', + action='store_true', + default=False, + help='disable TF32 in A100 GPUs') parser.add_argument( '--auto-scale-lr', action='store_true', @@ -116,6 +121,11 @@ def main(): # if 'runner_type' is set in the cfg runner = RUNNERS.build(cfg) + if args.disable_tf32: + import torch + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + # start training runner.train() From bb907ab1fec93ebfda2d8f74f80e192cf874e76e Mon Sep 17 00:00:00 2001 From: JingweiZhang12 Date: Mon, 26 Dec 2022 16:28:34 +0800 Subject: [PATCH 03/18] using our iou3d and nms3d, using basicblock in mmdet, simplify code in projects --- mmdet3d/models/layers/sparse_block.py | 5 +- .../centerformer/centerformer_head.py | 53 ++------------ .../centerformer/rpn_transformer.py | 11 +-- .../centerformer/utils/__init__.py | 11 +-- .../centerformer/utils/bbox_ops.py | 71 +++++-------------- tools/train.py | 13 ++-- 6 files changed, 36 insertions(+), 128 deletions(-) diff --git a/mmdet3d/models/layers/sparse_block.py b/mmdet3d/models/layers/sparse_block.py index 3d3dd66274..14fc4deeda 100644 --- a/mmdet3d/models/layers/sparse_block.py +++ b/mmdet3d/models/layers/sparse_block.py @@ -2,13 +2,10 @@ from typing import Tuple, Union from mmcv.cnn import build_conv_layer, build_norm_layer -# from mmdet.models.backbones.resnet import 
BasicBlock, Bottleneck -from mmdet.models.backbones.resnet import Bottleneck +from mmdet.models.backbones.resnet import BasicBlock, Bottleneck from torch import nn from mmdet3d.utils import OptConfigType -from projects.centerformer.centerformer.utils import \ - BasicBlockBias as BasicBlock # noqa: E501 from .spconv import IS_SPCONV2_AVAILABLE if IS_SPCONV2_AVAILABLE: diff --git a/projects/centerformer/centerformer/centerformer_head.py b/projects/centerformer/centerformer/centerformer_head.py index dc6ea9feb2..f5c197c7a1 100644 --- a/projects/centerformer/centerformer/centerformer_head.py +++ b/projects/centerformer/centerformer/centerformer_head.py @@ -18,9 +18,9 @@ from mmdet3d.models.layers import circle_nms, nms_bev from mmdet3d.registry import MODELS -from mmdet3d.structures import bbox_overlaps_3d from .losses import FastFocalLoss -from .utils import boxes_iou3d_gpu_pcdet, rotate_nms_pcdet +from mmcv.ops import boxes_iou3d +from .utils import nms_iou3d class SepHead(nn.Module): @@ -266,16 +266,8 @@ def loss(self, preds_dicts, example, **kwargs): preds_dict['hm'].shape[2], preds_dict['hm'].shape[3], ) - # iou_targets = bbox_overlaps_3d( - # preds_box.reshape(-1, 7), - # cur_gt.reshape(-1, 7), - # coordinate='lidar')[ - # range(preds_box.reshape(-1, 7).shape[0]), - # range(cur_gt.reshape(-1, 7).shape[0])] - - preds_box[:, :, 2] += preds_box[:, :, 5] / 2 - cur_gt[:, :, 2] += cur_gt[:, :, 5] / 2 - iou_targets = boxes_iou3d_gpu_pcdet( + + iou_targets = boxes_iou3d( preds_box.reshape(-1, 7), cur_gt.reshape( -1, 7))[range(preds_box.reshape(-1, 7).shape[0]), range(cur_gt.reshape(-1, 7).shape[0])] @@ -288,24 +280,11 @@ def loss(self, preds_dicts, example, **kwargs): losses.update( {f'{task_id}_iou_loss': iou_loss * self.iou_weight}) - # loss = hm_loss + self.weight * loc_loss - # if self.use_iou_loss: - # loss = loss + self.iou_weight * iou_loss - # if self.corner_loss: - # loss = loss + self.corner_weight * corner_loss losses.update({ f'{task_id}_hm_loss': hm_loss, - f'{task_id}_loc_loss': loc_loss * self.weight, - # 'loc_loss_elem': box_loss, - # 'num_positive': example['mask'][task_id].float().sum(), + f'{task_id}_loc_loss': loc_loss * self.weight }) - # """convert batch-key to key-batch""" - # rets_merged = defaultdict(list) - # for ret in rets: - # for k, v in ret.items(): - # rets_merged[k].append(v) - return losses @torch.no_grad() @@ -314,9 +293,7 @@ def predict(self, preds_dicts, batch_input_metas, **kwargs): Additionally support double flip testing """ - # get loss info rets = [] - metas = [] post_center_range = self.test_cfg.post_center_limit_range if len(post_center_range) > 0: @@ -333,13 +310,6 @@ def predict(self, preds_dicts, batch_input_metas, **kwargs): if len(preds_dict[key].shape) == 3: preds_dict[key] = val.permute(0, 2, 1).contiguous() - batch_size = preds_dict['scores'].shape[0] - - # if 'metadata' not in example or len(example['metadata']) == 0: - # meta_list = [None] * batch_size - # else: - # meta_list = example['metadata'] - batch_score = preds_dict['scores'] batch_label = preds_dict['labels'] batch_mask = preds_dict['mask'] @@ -347,10 +317,6 @@ def predict(self, preds_dicts, batch_input_metas, **kwargs): batch_iou = preds_dict['iou'].squeeze(2) else: batch_iou = None - if 'corner_hm' in preds_dict: - batch_corner_hm = preds_dict['corner_hm'] - else: - batch_corner_hm = None batch_dim = torch.exp(preds_dict['dim']) @@ -397,8 +363,6 @@ def predict(self, preds_dicts, batch_input_metas, **kwargs): batch_box_preds = torch.cat( [xs, ys, batch_hei, batch_dim, batch_rot], 
dim=2) - # metas.append(meta_list) - if self.test_cfg.get('per_class_nms', False): pass else: @@ -515,7 +479,7 @@ def post_processing( # pre_max_size=test_cfg.nms.nms_pre_max_size[c], # post_max_size=test_cfg.nms.nms_post_max_size[c], # ) - select = rotate_nms_pcdet( + select = nms_iou3d( boxes_for_nms[class_mask].float(), scores[class_mask].float(), thresh=test_cfg.nms.nms_iou_threshold[c], @@ -582,7 +546,6 @@ def get_box(pred_boxs, order, test_cfg, H, W): pred = torch.cat( [xs, ys, pred_boxs[:, 2:3], torch.exp(pred_boxs[:, 3:6]), rot], dim=1) - pred[:, 2] = pred[:, 2] - pred[:, 5] / 2 return torch.transpose(pred, 1, 2).contiguous() # B M 7 @@ -613,10 +576,6 @@ def get_box_gt(gt_boxs, order, test_cfg, H, W): batch_box_targets = torch.cat( [xs, ys, batch_gt_hei, batch_gt_dim, batch_gt_rot], dim=-1) - batch_box_targets[..., - 2] = batch_box_targets[..., - 2] - batch_box_targets[..., 5] / 2 - return batch_box_targets # B M 7 diff --git a/projects/centerformer/centerformer/rpn_transformer.py b/projects/centerformer/centerformer/rpn_transformer.py index 89f7720bbb..06eedd629f 100644 --- a/projects/centerformer/centerformer/rpn_transformer.py +++ b/projects/centerformer/centerformer/rpn_transformer.py @@ -1,4 +1,3 @@ -# Copyright (c) OpenMMLab. All rights reserved. from typing import Dict, List, Tuple import numpy as np @@ -209,12 +208,6 @@ def _make_layer(self, inplanes, planes, num_blocks, stride=1): return block, planes - # default init_weights for conv(msra) and norm in ConvModule - # def init_weights(self): - # for m in self.modules(): - # if isinstance(m, nn.Conv2d): - # xavier_init(m, distribution='uniform') - def forward(self, x, example=None): pass @@ -432,10 +425,10 @@ def forward(self, x, batch_data_samples): if self.corner and self.corner_head.training: corner_hm = self.corner_head(x_up) - corner_hm = torch.sigmoid(corner_hm) + corner_hm = self._sigmoid(corner_hm) # find top K center location - hm = torch.sigmoid(hm) + hm = self._sigmoid(hm) batch, num_cls, H, W = hm.size() scores, labels = torch.max( diff --git a/projects/centerformer/centerformer/utils/__init__.py b/projects/centerformer/centerformer/utils/__init__.py index 5509322aab..a7f122aea7 100644 --- a/projects/centerformer/centerformer/utils/__init__.py +++ b/projects/centerformer/centerformer/utils/__init__.py @@ -1,11 +1,4 @@ -from .attention import ChannelAttention, SpatialAttention -from .bbox_ops import boxes_iou3d_gpu_pcdet, rotate_nms_pcdet -from .multi_scale_deform_attn import MSDeformAttn -from .sparse_block import BasicBlockBias +from .bbox_ops import nms_iou3d from .transformer import Deform_Transformer -__all__ = [ - 'ChannelAttention', 'SpatialAttention', 'BasicBlockBias', 'MSDeformAttn', - 'MSDeformAttn', 'Deform_Transformer', 'boxes_iou3d_gpu_pcdet', - 'rotate_nms_pcdet' -] +__all__ = ['Deform_Transformer', 'nms_iou3d'] diff --git a/projects/centerformer/centerformer/utils/bbox_ops.py b/projects/centerformer/centerformer/utils/bbox_ops.py index fee864ecb9..4d73426cc7 100644 --- a/projects/centerformer/centerformer/utils/bbox_ops.py +++ b/projects/centerformer/centerformer/utils/bbox_ops.py @@ -1,22 +1,24 @@ import torch -from . 
import iou3d_nms_cuda +from mmcv.ops import nms3d -def rotate_nms_pcdet(boxes, - scores, - thresh, - pre_maxsize=None, - post_max_size=None): - """ - :param boxes: (N, 5) [x, y, z, l, w, h, theta] - :param scores: (N) - :param thresh: - :return: +def nms_iou3d(boxes, scores, thresh, pre_maxsize=None, post_max_size=None): + """NMS function GPU implementation (using IoU3D) + + Args: + boxes (Tensor): Input boxes with the shape of [N, 5] + ([x1, y1, x2, y2, ry]). + scores (Tensor): Scores of boxes with the shape of [N]. + thresh (float): Overlap threshold of NMS. + pre_max_size (int, optional): Max size of boxes before NMS. + Defaults to None. + post_max_size (int, optional): Max size of boxes after NMS. + Defaults to None. + + Returns: + Tensor: Indexes after NMS. """ - # transform back to pcdet's coordinate - # boxes = boxes[:, [0, 1, 2, 4, 3, 5, -1]] - # boxes[:, -1] = -boxes[:, -1] - np.pi / 2 order = scores.sort(0, descending=True)[1] if pre_maxsize is not None: @@ -29,48 +31,11 @@ def rotate_nms_pcdet(boxes, if len(boxes) == 0: num_out = 0 else: - num_out = iou3d_nms_cuda.nms_gpu(boxes, keep, thresh) + num_out = nms3d(boxes, keep, thresh) - selected = order[keep[:num_out].cuda()].contiguous() + selected = order[keep[:num_out].to(scores.device())].contiguous() if post_max_size is not None: selected = selected[:post_max_size] return selected - - -def boxes_iou3d_gpu_pcdet(boxes_a, boxes_b): - """ - Args: - boxes_a: (N, 7) [x, y, z, dx, dy, dz, heading] - boxes_b: (N, 7) [x, y, z, dx, dy, dz, heading] - Returns: - ans_iou: (N, M) - """ - assert boxes_a.shape[1] == boxes_b.shape[1] == 7 - - # height overlap - boxes_a_height_max = (boxes_a[:, 2] + boxes_a[:, 5] / 2).view(-1, 1) - boxes_a_height_min = (boxes_a[:, 2] - boxes_a[:, 5] / 2).view(-1, 1) - boxes_b_height_max = (boxes_b[:, 2] + boxes_b[:, 5] / 2).view(1, -1) - boxes_b_height_min = (boxes_b[:, 2] - boxes_b[:, 5] / 2).view(1, -1) - - # bev overlap - overlaps_bev = torch.cuda.FloatTensor( - torch.Size((boxes_a.shape[0], boxes_b.shape[0]))).zero_() # (N, M) - iou3d_nms_cuda.boxes_overlap_bev_gpu(boxes_a.contiguous(), - boxes_b.contiguous(), overlaps_bev) - - max_of_min = torch.max(boxes_a_height_min, boxes_b_height_min) - min_of_max = torch.min(boxes_a_height_max, boxes_b_height_max) - overlaps_h = torch.clamp(min_of_max - max_of_min, min=0) - - # 3d iou - overlaps_3d = overlaps_bev * overlaps_h - - vol_a = (boxes_a[:, 3] * boxes_a[:, 4] * boxes_a[:, 5]).view(-1, 1) - vol_b = (boxes_b[:, 3] * boxes_b[:, 4] * boxes_b[:, 5]).view(1, -1) - - iou3d = overlaps_3d / torch.clamp(vol_a + vol_b - overlaps_3d, min=1e-6) - - return iou3d diff --git a/tools/train.py b/tools/train.py index fe8ccbfa53..a31de28fb2 100644 --- a/tools/train.py +++ b/tools/train.py @@ -25,7 +25,7 @@ def parse_args(): '--disable-tf32', action='store_true', default=False, - help='disable TF32 in A100 GPUs') + help='disable TF32 on A100 GPUs') parser.add_argument( '--auto-scale-lr', action='store_true', @@ -98,6 +98,12 @@ def main(): f'`OptimWrapper` but got {optim_wrapper}.') cfg.optim_wrapper.type = 'AmpOptimWrapper' cfg.optim_wrapper.loss_scale = 'dynamic' + + if args.disable_tf32: + import torch + torch.backends.cuda.matmul.allow_tf32 = False + torch.backends.cudnn.allow_tf32 = False + print_log('Disable TF32 on A100 GPUs', logger='current') # enable automatically scaling LR if args.auto_scale_lr: @@ -121,11 +127,6 @@ def main(): # if 'runner_type' is set in the cfg runner = RUNNERS.build(cfg) - if args.disable_tf32: - import torch - 
torch.backends.cuda.matmul.allow_tf32 = False - torch.backends.cudnn.allow_tf32 = False - # start training runner.train() From a593b9146da97267ebf7b03f1c0c6abf5cafd343 Mon Sep 17 00:00:00 2001 From: JingweiZhang12 Date: Mon, 26 Dec 2022 16:46:41 +0800 Subject: [PATCH 04/18] remove attention.py and sparse_block.py --- .../centerformer/utils/attention.py | 42 --------- .../centerformer/utils/sparse_block.py | 88 ------------------- ...-atten_4xb4-cyclic-20e_waymoD5-3d-class.py | 4 +- 3 files changed, 3 insertions(+), 131 deletions(-) delete mode 100644 projects/centerformer/centerformer/utils/attention.py delete mode 100644 projects/centerformer/centerformer/utils/sparse_block.py diff --git a/projects/centerformer/centerformer/utils/attention.py b/projects/centerformer/centerformer/utils/attention.py deleted file mode 100644 index fa2b7d765a..0000000000 --- a/projects/centerformer/centerformer/utils/attention.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch -from mmengine.model import BaseModule -from torch import nn - - -class ChannelAttention(BaseModule): - - def __init__(self, in_planes, ratio=16): - super(ChannelAttention, self).__init__() - self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.max_pool = nn.AdaptiveMaxPool2d(1) - - self.fc = nn.Sequential( - nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False), - nn.ReLU(), - nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False), - ) - self.sigmoid = nn.Sigmoid() - - def forward(self, x): - avg_out = self.fc(self.avg_pool(x)) - max_out = self.fc(self.max_pool(x)) - out = avg_out + max_out - return self.sigmoid(out) * x - - -class SpatialAttention(BaseModule): - - def __init__(self, kernel_size=7): - super(SpatialAttention, self).__init__() - - self.conv1 = nn.Conv2d( - 2, 1, kernel_size, padding=kernel_size // 2, bias=False) - self.sigmoid = nn.Sigmoid() - - def forward(self, x): - avg_out = torch.mean(x, dim=1, keepdim=True) - max_out, _ = torch.max(x, dim=1, keepdim=True) - y = torch.cat([avg_out, max_out], dim=1) - y = self.conv1(y) - return self.sigmoid(y) * x diff --git a/projects/centerformer/centerformer/utils/sparse_block.py b/projects/centerformer/centerformer/utils/sparse_block.py deleted file mode 100644 index 6d9221aeab..0000000000 --- a/projects/centerformer/centerformer/utils/sparse_block.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright (c) OpenMMLab. All rights reserved. -import torch.utils.checkpoint as cp -from mmcv.cnn import build_conv_layer, build_norm_layer -from mmengine.model import BaseModule -from torch import nn - - -class BasicBlockBias(BaseModule): - expansion = 1 - - def __init__(self, - inplanes, - planes, - stride=1, - dilation=1, - downsample=None, - style='pytorch', - with_cp=False, - conv_cfg=None, - norm_cfg=dict(type='BN'), - dcn=None, - plugins=None, - init_cfg=None): - super(BasicBlockBias, self).__init__(init_cfg) - assert dcn is None, 'Not implemented yet.' - assert plugins is None, 'Not implemented yet.' 
- - self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) - self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) - - self.conv1 = build_conv_layer( - conv_cfg, - inplanes, - planes, - 3, - stride=stride, - padding=dilation, - dilation=dilation, - bias=True) - self.add_module(self.norm1_name, norm1) - self.conv2 = build_conv_layer( - conv_cfg, planes, planes, 3, padding=1, bias=True) - self.add_module(self.norm2_name, norm2) - - self.relu = nn.ReLU(inplace=True) - self.downsample = downsample - self.stride = stride - self.dilation = dilation - self.with_cp = with_cp - - @property - def norm1(self): - """nn.Module: normalization layer after the first convolution layer""" - return getattr(self, self.norm1_name) - - @property - def norm2(self): - """nn.Module: normalization layer after the second convolution layer""" - return getattr(self, self.norm2_name) - - def forward(self, x): - """Forward function.""" - - def _inner_forward(x): - identity = x - - out = self.conv1(x) - out = self.norm1(out) - out = self.relu(out) - - out = self.conv2(out) - out = self.norm2(out) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - - return out - - if self.with_cp and x.requires_grad: - out = cp.checkpoint(_inner_forward, x) - else: - out = _inner_forward(x) - - out = self.relu(out) - - return out diff --git a/projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py b/projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py index f71230f2f9..0dff8feaa2 100644 --- a/projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py +++ b/projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py @@ -1,4 +1,6 @@ _base_ = ['mmdet3d::_base_/default_runtime.py'] +custom_imports = dict( + imports=['projects.centerformer.centerformer'], allow_failed_imports=False) # model settings # Voxel size for voxel encoder @@ -104,7 +106,7 @@ obj_num=1000, )) -data_root = 'data/waymo/kitti_format/' +data_root = 'data/waymo_mini/kitti_format/' db_sampler = dict( data_root=data_root, info_path=data_root + 'waymo_dbinfos_train.pkl', From 2b9b347bbc0fa711a67bf08f8caa3ff8200b42d9 Mon Sep 17 00:00:00 2001 From: JingweiZhang12 Date: Mon, 26 Dec 2022 20:01:35 +0800 Subject: [PATCH 05/18] only using single fold --- projects/centerformer/centerformer/__init__.py | 6 ++++-- .../centerformer/{utils => }/bbox_ops.py | 0 ...transformer.py => centerformer_backbone.py} | 6 +++++- .../centerformer/centerformer_head.py | 2 +- .../{utils => }/multi_scale_deform_attn.py | 18 +++++++++++------- .../centerformer/{utils => }/transformer.py | 6 ++++++ .../centerformer/utils/__init__.py | 4 ---- ...n-atten_4xb4-cyclic-20e_waymoD5-3d-class.py | 2 +- 8 files changed, 28 insertions(+), 16 deletions(-) rename projects/centerformer/centerformer/{utils => }/bbox_ops.py (100%) rename projects/centerformer/centerformer/{rpn_transformer.py => centerformer_backbone.py} (99%) rename projects/centerformer/centerformer/{utils => }/multi_scale_deform_attn.py (96%) rename projects/centerformer/centerformer/{utils => }/transformer.py (98%) delete mode 100644 projects/centerformer/centerformer/utils/__init__.py diff --git a/projects/centerformer/centerformer/__init__.py b/projects/centerformer/centerformer/__init__.py index 5893950e9b..74f7657fb8 100644 --- 
a/projects/centerformer/centerformer/__init__.py +++ b/projects/centerformer/centerformer/__init__.py @@ -1,9 +1,11 @@ from .centerformer import CenterFormer from .centerformer_head import CenterHeadIoU_1d from .losses import FastFocalLoss -from .rpn_transformer import RPN_transformer_deformable +from .centerformer_backbone import (RPN_transformer_deformable, + RPN_transformer_deformable_mtf) +from .bbox_ops import nms_iou3d __all__ = [ 'CenterFormer', 'RPN_transformer_deformable', 'CenterHeadIoU_1d', - 'FastFocalLoss' + 'FastFocalLoss', 'nms_iou3d', 'RPN_transformer_deformable_mtf' ] diff --git a/projects/centerformer/centerformer/utils/bbox_ops.py b/projects/centerformer/centerformer/bbox_ops.py similarity index 100% rename from projects/centerformer/centerformer/utils/bbox_ops.py rename to projects/centerformer/centerformer/bbox_ops.py diff --git a/projects/centerformer/centerformer/rpn_transformer.py b/projects/centerformer/centerformer/centerformer_backbone.py similarity index 99% rename from projects/centerformer/centerformer/rpn_transformer.py rename to projects/centerformer/centerformer/centerformer_backbone.py index 06eedd629f..7b14273530 100644 --- a/projects/centerformer/centerformer/rpn_transformer.py +++ b/projects/centerformer/centerformer/centerformer_backbone.py @@ -12,7 +12,7 @@ from mmdet3d.models.utils import draw_heatmap_gaussian, gaussian_radius from mmdet3d.registry import MODELS from mmdet3d.structures import center_to_corner_box2d -from .utils import Deform_Transformer +from .transformer import Deform_Transformer class ChannelAttention(nn.Module): @@ -334,6 +334,10 @@ def get_multi_scale_feature_multiframe(self, center_pos, feats, timeframe): @MODELS.register_module() class RPN_transformer_deformable(RPN_transformer_base): + '''The original implement of CenterFormer modules. It fusion the backbone + neck and heatmap head into one module. + + ''' def __init__( self, diff --git a/projects/centerformer/centerformer/centerformer_head.py b/projects/centerformer/centerformer/centerformer_head.py index f5c197c7a1..039bb7a8a9 100644 --- a/projects/centerformer/centerformer/centerformer_head.py +++ b/projects/centerformer/centerformer/centerformer_head.py @@ -20,7 +20,7 @@ from mmdet3d.registry import MODELS from .losses import FastFocalLoss from mmcv.ops import boxes_iou3d -from .utils import nms_iou3d +from .bbox_ops import nms_iou3d class SepHead(nn.Module): diff --git a/projects/centerformer/centerformer/utils/multi_scale_deform_attn.py b/projects/centerformer/centerformer/multi_scale_deform_attn.py similarity index 96% rename from projects/centerformer/centerformer/utils/multi_scale_deform_attn.py rename to projects/centerformer/centerformer/multi_scale_deform_attn.py index ca2467ac88..7417462827 100644 --- a/projects/centerformer/centerformer/utils/multi_scale_deform_attn.py +++ b/projects/centerformer/centerformer/multi_scale_deform_attn.py @@ -85,6 +85,17 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple: class MSDeformAttn(nn.Module): + """Multi-Scale Deformable Attention Module. + Note that the difference between this implementation and the implementation + in MMCV is that the dimension of input and hidden embedding in the + multi-attention-head can be specified respectively. 
+ + :param d_model input dimension + :param d_head hidden dimension + :param n_levels number of feature levels + :param n_heads number of attention heads + :param n_points number of sampling points per attention head per feature level # noqa: E501 + """ def __init__(self, d_model=256, @@ -93,13 +104,6 @@ def __init__(self, n_heads=8, n_points=4, out_sample_loc=False): - """Multi-Scale Deformable Attention Module. - - :param d_model hidden dimension - :param n_levels number of feature levels - :param n_heads number of attention heads - :param n_points number of sampling points per attention head per feature level # noqa: E501 - """ super().__init__() self.im2col_step = 64 diff --git a/projects/centerformer/centerformer/utils/transformer.py b/projects/centerformer/centerformer/transformer.py similarity index 98% rename from projects/centerformer/centerformer/utils/transformer.py rename to projects/centerformer/centerformer/transformer.py index f7b06e186c..2b231091c1 100644 --- a/projects/centerformer/centerformer/utils/transformer.py +++ b/projects/centerformer/centerformer/transformer.py @@ -304,6 +304,12 @@ def forward(self, class Deform_Transformer(nn.Module): + '''Deformable transformer. + Note that the difference between this implementation and the implementation + in MMDet is that the dimension of input and hidden embedding in the + multi-attention-head can be specified respectively. + + ''' def __init__( self, diff --git a/projects/centerformer/centerformer/utils/__init__.py b/projects/centerformer/centerformer/utils/__init__.py deleted file mode 100644 index a7f122aea7..0000000000 --- a/projects/centerformer/centerformer/utils/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .bbox_ops import nms_iou3d -from .transformer import Deform_Transformer - -__all__ = ['Deform_Transformer', 'nms_iou3d'] diff --git a/projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py b/projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py index 0dff8feaa2..103fcdd1f8 100644 --- a/projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py +++ b/projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py @@ -106,7 +106,7 @@ obj_num=1000, )) -data_root = 'data/waymo_mini/kitti_format/' +data_root = 'data/waymo/kitti_format/' db_sampler = dict( data_root=data_root, info_path=data_root + 'waymo_dbinfos_train.pkl', From 0ed875f6140d1a4ae2e3ce7747a08bcb09deb519 Mon Sep 17 00:00:00 2001 From: JingweiZhang12 Date: Tue, 27 Dec 2022 17:38:10 +0800 Subject: [PATCH 06/18] polish code --- .../hooks/disable_object_sample_hook.py | 2 +- .../centerformer/centerformer/__init__.py | 12 +- .../centerformer/centerformer/bbox_ops.py | 28 +-- .../centerformer/centerformer_backbone.py | 108 +++++---- .../centerformer/centerformer_head.py | 37 +-- .../centerformer/multi_scale_deform_attn.py | 80 ++++--- .../centerformer/centerformer/transformer.py | 214 ++---------------- ...-attn_4xb4-cyclic-20e_waymoD5-3d-class.py} | 14 +- tools/train.py | 2 +- 9 files changed, 163 insertions(+), 334 deletions(-) rename projects/centerformer/configs/{centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py => centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py} (97%) diff --git a/mmdet3d/engine/hooks/disable_object_sample_hook.py 
b/mmdet3d/engine/hooks/disable_object_sample_hook.py index 293e18d105..3a8182c45e 100644 --- a/mmdet3d/engine/hooks/disable_object_sample_hook.py +++ b/mmdet3d/engine/hooks/disable_object_sample_hook.py @@ -36,7 +36,7 @@ def before_train_epoch(self, runner: Runner): model = model.module if epoch == self.disable_after_epoch: runner.logger.info('Disable ObjectSample') - for transform in runner.train_dataloader.dataset.pipeline.transforms: + for transform in runner.train_dataloader.dataset.pipeline.transforms: # noqa: E501 if isinstance(transform, ObjectSample): assert hasattr(transform, 'disabled') transform.disabled = True diff --git a/projects/centerformer/centerformer/__init__.py b/projects/centerformer/centerformer/__init__.py index 74f7657fb8..3bd38cd670 100644 --- a/projects/centerformer/centerformer/__init__.py +++ b/projects/centerformer/centerformer/__init__.py @@ -1,11 +1,11 @@ +from .bbox_ops import nms_iou3d from .centerformer import CenterFormer -from .centerformer_head import CenterHeadIoU_1d +from .centerformer_backbone import (DeformableDecoderRPN, + MultiFrameDeformableDecoderRPN) +from .centerformer_head import CenterFormerBboxHead from .losses import FastFocalLoss -from .centerformer_backbone import (RPN_transformer_deformable, - RPN_transformer_deformable_mtf) -from .bbox_ops import nms_iou3d __all__ = [ - 'CenterFormer', 'RPN_transformer_deformable', 'CenterHeadIoU_1d', - 'FastFocalLoss', 'nms_iou3d', 'RPN_transformer_deformable_mtf' + 'CenterFormer', 'DeformableDecoderRPN', 'CenterFormerBboxHead', + 'FastFocalLoss', 'nms_iou3d', 'MultiFrameDeformableDecoderRPN' ] diff --git a/projects/centerformer/centerformer/bbox_ops.py b/projects/centerformer/centerformer/bbox_ops.py index 4d73426cc7..d850eaf8ee 100644 --- a/projects/centerformer/centerformer/bbox_ops.py +++ b/projects/centerformer/centerformer/bbox_ops.py @@ -1,10 +1,13 @@ import torch +from mmcv.utils import ext_loader -from mmcv.ops import nms3d +ext_module = ext_loader.load_ext('_ext', ['iou3d_nms3d_forward']) def nms_iou3d(boxes, scores, thresh, pre_maxsize=None, post_max_size=None): - """NMS function GPU implementation (using IoU3D) + """NMS function GPU implementation (using IoU3D) The difference of this + implementation with nms3d in MMCV is that we add `pre_maxsize` and + `post_max_size` before and after NMS respectively. Args: boxes (Tensor): Input boxes with the shape of [N, 5] @@ -19,23 +22,20 @@ def nms_iou3d(boxes, scores, thresh, pre_maxsize=None, post_max_size=None): Returns: Tensor: Indexes after NMS. 
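A minimal usage sketch of the interface documented above (the boxes, scores, and thresholds are made-up values; note that although the Args section still says `[N, 5]`, the implementation below asserts 7-dim boxes `(x, y, z, dx, dy, dz, ry)`, and the underlying op expects CUDA tensors):

```python
import torch

# hypothetical inputs: four 7-dim boxes and their scores on the GPU
boxes = torch.rand(4, 7).cuda()
scores = torch.rand(4).cuda()

# keep at most 500 boxes before NMS and 83 after, with an IoU threshold of 0.7
keep = nms_iou3d(boxes, scores, thresh=0.7, pre_maxsize=500, post_max_size=83)
selected = boxes[keep]
```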
""" - + # TODO: directly refactor ``nms3d`` in MMCV + assert boxes.size(1) == 7, 'Input boxes shape should be (N, 7)' order = scores.sort(0, descending=True)[1] if pre_maxsize is not None: order = order[:pre_maxsize] - boxes = boxes[order].contiguous() - keep = torch.LongTensor(boxes.size(0)) - - if len(boxes) == 0: - num_out = 0 - else: - num_out = nms3d(boxes, keep, thresh) - - selected = order[keep[:num_out].to(scores.device())].contiguous() + keep = boxes.new_zeros(boxes.size(0), dtype=torch.long) + num_out = boxes.new_zeros(size=(), dtype=torch.long) + ext_module.iou3d_nms3d_forward( + boxes, keep, num_out, nms_overlap_thresh=thresh) + keep = order[keep[:num_out].to(boxes.device)].contiguous() if post_max_size is not None: - selected = selected[:post_max_size] + keep = keep[:post_max_size] - return selected + return keep diff --git a/projects/centerformer/centerformer/centerformer_backbone.py b/projects/centerformer/centerformer/centerformer_backbone.py index 7b14273530..5bef2cd6fe 100644 --- a/projects/centerformer/centerformer/centerformer_backbone.py +++ b/projects/centerformer/centerformer/centerformer_backbone.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple +from typing import List, Tuple import numpy as np import torch @@ -7,12 +7,11 @@ from mmengine.logging import print_log from mmengine.structures import InstanceData from torch import Tensor, nn -from torch.nn import functional as F from mmdet3d.models.utils import draw_heatmap_gaussian, gaussian_radius from mmdet3d.registry import MODELS from mmdet3d.structures import center_to_corner_box2d -from .transformer import Deform_Transformer +from .transformer import DeformableTransformer class ChannelAttention(nn.Module): @@ -53,10 +52,10 @@ def forward(self, x): return self.sigmoid(y) * x -class SpatialAttention_mtf(nn.Module): +class MultiFrameSpatialAttention(nn.Module): def __init__(self, kernel_size=7): - super(SpatialAttention_mtf, self).__init__() + super(MultiFrameSpatialAttention, self).__init__() self.conv1 = nn.Conv2d( 2, 1, kernel_size, padding=kernel_size // 2, bias=False) @@ -70,7 +69,7 @@ def forward(self, curr, prev): return self.sigmoid(y) * prev -class RPN_transformer_base(nn.Module): +class BaseDecoderRPN(nn.Module): def __init__( self, @@ -90,7 +89,7 @@ def __init__( score_threshold=0.1, obj_num=500, **kwargs): - super(RPN_transformer_base, self).__init__() + super(BaseDecoderRPN, self).__init__() self._layer_strides = [1, 2, -4] self._num_filters = ds_num_filters self._layer_nums = layer_nums @@ -246,7 +245,6 @@ def get_multi_scale_feature(self, center_pos, feats): feat_id = (neighbor_coords[:, :, :, 1] * (W // (2**i)) + neighbor_coords[:, :, :, 0]) # pixel id [B, 500, k] feat_id = feat_id.reshape(batch, -1) # pixel id [B, 500*k] - # selected_feat = torch.gather(feats[i].reshape(batch, num_cls,(H*W)//(4**i)).permute(0, 2, 1).contiguous(),1,feat_id) selected_feat = ( feats[i].reshape(batch, num_cls, (H * W) // (4**i)).permute( 0, 2, 1).contiguous()[self.batch_id.repeat(1, k**2), @@ -255,7 +253,6 @@ def get_multi_scale_feature(self, center_pos, feats): selected_feat.reshape(batch, center_num, -1, num_cls)) # B, 500, k, C relative_pos_list.append(neighbor_coords * (2**i)) # B, 500, k, 2 - # relative_pos_list.append(F.pad(neighbor_coords*(2**i), (0,1), "constant", i)) # B, 500, k, 3 neighbor_pos = torch.cat(relative_pos_list, dim=2) # B, 500, K, 2/3 neighbor_feats = torch.cat(neighbor_feat_list, dim=2) # B, 500, K, C @@ -333,35 +330,36 @@ def get_multi_scale_feature_multiframe(self, center_pos, feats, 
timeframe): @MODELS.register_module() -class RPN_transformer_deformable(RPN_transformer_base): - '''The original implement of CenterFormer modules. It fusion the backbone - neck and heatmap head into one module. - - ''' - - def __init__( - self, - layer_nums, # [2,2,2] - ds_num_filters, # [128,256,64] - num_input_features, - tasks=dict(), - transformer_config=None, - hm_head_layer=2, - corner_head_layer=2, - corner=False, - parametric_embedding=False, - assign_label_window_size=1, - classes=3, - use_gt_training=False, - norm_cfg=None, - logger=None, - init_bias=-2.19, - score_threshold=0.1, - obj_num=500, - train_cfg=None, - test_cfg=None, - **kwargs): - super(RPN_transformer_deformable, self).__init__( +class DeformableDecoderRPN(BaseDecoderRPN): + """The original implement of CenterFormer modules. + + It fuse the backbone neck and heatmap head into one module. + + TODO: split this module into backbone、neck and head. + """ + + def __init__(self, + layer_nums, + ds_num_filters, + num_input_features, + tasks=dict(), + transformer_config=None, + hm_head_layer=2, + corner_head_layer=2, + corner=False, + parametric_embedding=False, + assign_label_window_size=1, + classes=3, + use_gt_training=False, + norm_cfg=None, + logger=None, + init_bias=-2.19, + score_threshold=0.1, + obj_num=500, + train_cfg=None, + test_cfg=None, + **kwargs): + super(DeformableDecoderRPN, self).__init__( layer_nums, ds_num_filters, num_input_features, @@ -381,16 +379,15 @@ def __init__( self.train_cfg = train_cfg self.test_cfg = test_cfg self.tasks = tasks - num_classes = [len(t['class_names']) for t in tasks] self.class_names = [t['class_names'] for t in tasks] - self.transformer_layer = Deform_Transformer( + self.transformer_layer = DeformableTransformer( self._num_filters[-1] * 2, depth=transformer_config.depth, heads=transformer_config.heads, dim_head=transformer_config.dim_head, - mlp_dim=transformer_config.MLP_dim, - dropout=transformer_config.DP_rate, + dim_ffn=transformer_config.dim_ffn, + dropout=transformer_config.dropout_rate, out_attention=transformer_config.out_att, n_points=transformer_config.get('n_points', 9), ) @@ -467,11 +464,8 @@ def forward(self, x, batch_data_samples): labels = torch.gather(labels, 1, order) mask = scores > self.score_threshold - ct_feat = (x_up.reshape(batch, -1, - H * W).transpose(2, - 1).contiguous()[self.batch_id, - order] - ) # B, 500, C + ct_feat = x_up.reshape(batch, -1, H * W).transpose(2, 1).contiguous() + ct_feat = ct_feat[self.batch_id, order] # B, 500, C # create position embedding for each center y_coor = order // W @@ -548,7 +542,7 @@ def get_targets( [ tensor0, tensor1, tensor2, ... ] Args: batch_gt_instances_3d (list[:obj:`InstanceData`]): Batch of - gt_instances. It usually includes ``bboxes_3d`` and\ + gt_instances. It usually includes ``bboxes_3d`` and ``labels_3d`` attributes. Returns: Returns: @@ -560,8 +554,9 @@ def get_targets( position of the valid boxes. - list[torch.Tensor]: Masks indicating which boxes are valid. + - list[torch.Tensor]: catagrate labels. """ - heatmaps, anno_boxes, inds, masks, corner_heatmaps, cat_labels = multi_apply( + heatmaps, anno_boxes, inds, masks, corner_heatmaps, cat_labels = multi_apply( # noqa: E501 self.get_targets_single, batch_gt_instances_3d) # Transpose heatmaps heatmaps = list(map(list, zip(*heatmaps))) @@ -599,6 +594,7 @@ def get_targets_single(self, of the valid boxes. - list[torch.Tensor]: Masks indicating which boxes are valid. + - list[torch.Tensor]: catagrate labels. 
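The heatmap targets assembled in these two methods follow the CenterPoint-style Gaussian splatting. A minimal sketch of that single step, assuming the `draw_heatmap_gaussian` and `gaussian_radius` helpers keep the signatures they have elsewhere in mmdet3d (the grid size, box size, and overlap threshold below are made-up values):

```python
import torch
from mmdet3d.models.utils import draw_heatmap_gaussian, gaussian_radius

heatmap = torch.zeros(188, 188)  # BEV grid for one class
# box length/width already scaled to feature-map cells
length, width = torch.tensor(9.0), torch.tensor(4.0)
radius = max(2, int(gaussian_radius((length, width), min_overlap=0.1)))
center = torch.tensor([94, 94], dtype=torch.int32)  # (x, y) cell of the box centre
draw_heatmap_gaussian(heatmap, center, radius)  # splats a Gaussian peak in place
```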
""" gt_labels_3d = gt_instances_3d.labels_3d gt_bboxes_3d = gt_instances_3d.bboxes_3d @@ -637,7 +633,7 @@ def get_targets_single(self, task_classes.append(torch.cat(task_class).long().to(device)) flag2 += len(mask) draw_gaussian = draw_heatmap_gaussian - heatmaps, anno_boxes, inds, masks, corner_heatmaps, cat_labels = [], [], [], [], [], [] + heatmaps, anno_boxes, inds, masks, corner_heatmaps, cat_labels = [], [], [], [], [], [] # noqa: E501 for idx in range(len(self.tasks)): heatmap = gt_bboxes_3d.new_zeros( @@ -759,7 +755,7 @@ def get_targets_single(self, @MODELS.register_module() -class RPN_transformer_deformable_mtf(RPN_transformer_base): +class MultiFrameDeformableDecoderRPN(BaseDecoderRPN): def __init__( self, @@ -781,7 +777,7 @@ def __init__( obj_num=500, frame=1, **kwargs): - super(RPN_transformer_deformable_mtf, self).__init__( + super(MultiFrameDeformableDecoderRPN, self).__init__( layer_nums, ds_num_filters, num_input_features, @@ -811,17 +807,17 @@ def __init__( build_norm_layer(self._norm_cfg, self._num_filters[0])[1], nn.ReLU(), ) - self.mtf_attention = SpatialAttention_mtf() + self.mtf_attention = MultiFrameSpatialAttention() self.time_embedding = nn.Linear(1, self._num_filters[0]) - self.transformer_layer = Deform_Transformer( + self.transformer_layer = DeformableTransformer( self._num_filters[-1] * 2, depth=transformer_config.depth, heads=transformer_config.heads, levels=2 + self.frame, dim_head=transformer_config.dim_head, - mlp_dim=transformer_config.MLP_dim, - dropout=transformer_config.DP_rate, + dim_ffn=transformer_config.dim_ffn, + dropout=transformer_config.dropout_rate, out_attention=transformer_config.out_att, n_points=transformer_config.get('n_points', 9), ) diff --git a/projects/centerformer/centerformer/centerformer_head.py b/projects/centerformer/centerformer/centerformer_head.py index 039bb7a8a9..051e298df1 100644 --- a/projects/centerformer/centerformer/centerformer_head.py +++ b/projects/centerformer/centerformer/centerformer_head.py @@ -11,6 +11,7 @@ import numpy as np import torch from mmcv.cnn import build_norm_layer +from mmcv.ops import boxes_iou3d from mmengine.logging import print_log from mmengine.model import kaiming_init from mmengine.structures import InstanceData @@ -18,12 +19,15 @@ from mmdet3d.models.layers import circle_nms, nms_bev from mmdet3d.registry import MODELS -from .losses import FastFocalLoss -from mmcv.ops import boxes_iou3d from .bbox_ops import nms_iou3d +from .losses import FastFocalLoss class SepHead(nn.Module): + """TODO: This module is the original implementation in CenterFormer and it + has few differences with ``SeperateHead`` in `mmdet3d` but refactor this + module will lower the performance a little. 
+ """ def __init__( self, @@ -85,13 +89,11 @@ def forward(self, x, y): @MODELS.register_module() -class CenterHeadIoU_1d(nn.Module): +class CenterFormerBboxHead(nn.Module): def __init__(self, - in_channels=[ - 128, - ], - tasks=[], + in_channels, + tasks, weight=0.25, iou_weight=1, corner_weight=1, @@ -108,7 +110,7 @@ def __init__(self, bbox_code_size=7, test_cfg=None, **kawrgs): - super(CenterHeadIoU_1d, self).__init__() + super(CenterFormerBboxHead, self).__init__() num_classes = [len(t['class_names']) for t in tasks] self.class_names = [t['class_names'] for t in tasks] @@ -136,7 +138,7 @@ def __init__(self, self.use_direction_classifier = False if not logger: - logger = logging.getLogger('CenterHeadIoU_1d') + logger = logging.getLogger('CenterFormerBboxHead') self.logger = logger logger.info(f'num_classes: {num_classes}') @@ -390,12 +392,6 @@ def predict(self, preds_dicts, batch_input_metas, **kwargs): if k == 'bboxes': bboxes = torch.cat([ret[i][k] for ret in rets]) bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 5] * 0.5 - # The original CenterFormer model predict (..., w,l,h) - # Note that this is used to align the precision of - # converted model - # bboxes[:, 4], bboxes[:, 3] = bboxes[:, 3].clone( - # ), bboxes[:, 4].clone() - # bboxes[:, 6] = -bboxes[:, 6] - np.pi / 2 bboxes = batch_input_metas[i]['box_type_3d']( bboxes, self.bbox_code_size) elif k == 'labels': @@ -406,7 +402,6 @@ def predict(self, preds_dicts, batch_input_metas, **kwargs): labels = torch.cat([ret[i][k] for ret in rets]) elif k == 'scores': scores = torch.cat([ret[i][k] for ret in rets]) - # ret['metadata'] = metas[0][i] temp_instances.bboxes_3d = bboxes temp_instances.scores_3d = scores @@ -469,16 +464,6 @@ def post_processing( class_mask = labels == c if class_mask.sum() > 0: class_idx = class_mask.nonzero() - # boxes_for_nms = xywhr2xyxyr( - # img_metas[i]['box_type_3d']( - # box_preds[:, :], self.bbox_code_size).bev) - # select = nms_bev( - # boxes_for_nms[class_mask].float(), - # scores[class_mask].float(), - # thresh=test_cfg.nms.nms_iou_threshold[c], - # pre_max_size=test_cfg.nms.nms_pre_max_size[c], - # post_max_size=test_cfg.nms.nms_post_max_size[c], - # ) select = nms_iou3d( boxes_for_nms[class_mask].float(), scores[class_mask].float(), diff --git a/projects/centerformer/centerformer/multi_scale_deform_attn.py b/projects/centerformer/centerformer/multi_scale_deform_attn.py index 7417462827..ec79dc8e3c 100644 --- a/projects/centerformer/centerformer/multi_scale_deform_attn.py +++ b/projects/centerformer/centerformer/multi_scale_deform_attn.py @@ -1,9 +1,10 @@ import math +from typing import Optional import torch import torch.nn.functional as F from mmcv.utils import ext_loader -from torch import nn +from torch import Tensor, nn from torch.autograd.function import Function, once_differentiable from torch.nn.init import constant_, xavier_uniform_ @@ -85,17 +86,21 @@ def backward(ctx, grad_output: torch.Tensor) -> tuple: class MSDeformAttn(nn.Module): - """Multi-Scale Deformable Attention Module. - Note that the difference between this implementation and the implementation - in MMCV is that the dimension of input and hidden embedding in the - multi-attention-head can be specified respectively. - - :param d_model input dimension - :param d_head hidden dimension - :param n_levels number of feature levels - :param n_heads number of attention heads - :param n_points number of sampling points per attention head per feature level # noqa: E501 - """ + """Multi-Scale Deformable Attention Module. 
Note that the difference + between this implementation and the implementation in MMCV is that the + dimension of input and hidden embedding in the multi-attention-head can be + specified respectively. + + Args: + d_model (int, optional): input dimension. Defaults to 256. + d_head (int, optional): hidden dimension. Defaults to 64. + n_levels (int, optional): number of feature levels. Defaults to 4. + n_heads (int, optional): number of attention heads. Defaults to 8. + n_points (int, optional): number of sampling points per attention head + per feature level. Defaults to 4. + out_sample_loc (bool, optional): Whether to return the sampling + location. Defaults to False. + """ def __init__(self, d_model=256, @@ -146,23 +151,36 @@ def _reset_parameters(self): constant_(self.output_proj.bias.data, 0.) def forward(self, - query, - reference_points, - input_flatten, - input_spatial_shapes, - input_level_start_index, - input_padding_mask=None): - """ - :param query: (N, Length_{query}, C) - :param reference_points - (N, Length_{query}, n_levels, 2), range in [0, 1], top-left (0,0), bottom-right (1, 1), including padding area # noqa: E501 - or (N, Length_{query}, n_levels, 4), add additional (w, h) to form reference boxes # noqa: E501 - :param input_flatten : (N, \sum_{l=0}^{L-1} H_l \cdot W_l, C) - :param input_spatial_shapes: - (n_levels, 2), [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] - :param input_level_start_index: (n_levels, ), [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, ..., H_0*W_0+H_1*W_1+...+H_{L-1}*W_{L-1}] # noqa: E501 - :param input_padding_mask (N, \sum_{l=0}^{L-1} H_l \cdot W_l), True for padding elements, False for non-padding elements # noqa: E501 - :return output (N, Length_{query}, C) + query: Tensor, + reference_points: Tensor, + input_flatten: Tensor, + input_spatial_shapes: Tensor, + input_level_start_index: Tensor, + input_padding_mask: Optional[Tensor] = None): + """Forward Function of MultiScaleDeformAttention. + + Args: + query (Tensor): (N, num_query, C) + reference_points (Tensor): (N, num_query, n_levels, 2). The + normalized reference points with shape + (bs, num_query, num_levels, 2), + all elements is range in [0, 1], top-left (0,0), + bottom-right (1, 1), including padding area. + or (N, Length_{query}, num_levels, 4), add + additional two dimensions is (w, h) to + form reference boxes. + input_flatten (Tensor): _description_ + input_spatial_shapes (Tensor): Spatial shape of features in + different levels. With shape (num_levels, 2), + last dimension represents (h, w). + input_level_start_index (Tensor): The start index of each level. + A tensor has shape ``(num_levels, )`` and can be represented + as [0, h_0*w_0, h_0*w_0+h_1*w_1, ...]. + input_padding_mask (Optional[Tensor], optional): The padding mask + for value. Defaults to None. + + Returns: + Tuple[Tensor, Tensor]: forwarded results. 
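A shape-level sketch of how these arguments fit together, using two imaginary feature levels and 500 proposal queries. The keyword names follow the constructor as it stands at this point in the series (a later commit renames them to `dim_model`/`dim_single_head`), and the underlying CUDA op requires a GPU build of mmcv:

```python
import torch

attn = MSDeformAttn(d_model=256, d_head=64, n_levels=2, n_heads=8, n_points=4).cuda()

spatial_shapes = torch.tensor([[32, 32], [16, 16]], device='cuda')  # (n_levels, 2)
level_start_index = torch.tensor([0, 32 * 32], device='cuda')       # (n_levels,)
value = torch.rand(2, 32 * 32 + 16 * 16, 256, device='cuda')        # flattened features
query = torch.rand(2, 500, 256, device='cuda')                      # proposal features
reference_points = torch.rand(2, 500, 2, 2, device='cuda')          # in [0, 1], per level

# result follows the Tuple[Tensor, Tensor] contract documented above
result = attn(query, reference_points, value, spatial_shapes, level_start_index)
```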
""" N, Len_q, _ = query.shape N, Len_in, _ = input_flatten.shape @@ -186,8 +204,8 @@ def forward(self, [input_spatial_shapes[..., 1], input_spatial_shapes[..., 0]], -1).to(sampling_offsets) - sampling_locations = reference_points[:, :, None, :, None, :] \ - + sampling_offsets / offset_normalizer[None, None, None, :, None, :] # noqa: E501 + sampling_locations = reference_points[:, :, None, :, None, :] + \ + sampling_offsets / offset_normalizer[None, None, None, :, None, :] # noqa: E501 elif reference_points.shape[-1] == 4: sampling_locations = reference_points[:, :, None, :, None, :2] \ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5 # noqa: E501 diff --git a/projects/centerformer/centerformer/transformer.py b/projects/centerformer/centerformer/transformer.py index 2b231091c1..7353c757b7 100644 --- a/projects/centerformer/centerformer/transformer.py +++ b/projects/centerformer/centerformer/transformer.py @@ -1,37 +1,11 @@ -import math - import torch from einops import rearrange +from mmcv.cnn.bricks.activation import GELU from torch import einsum, nn -from torch.nn import functional as F from .multi_scale_deform_attn import MSDeformAttn -class MLP(nn.Module): - """Very simple multi-layer perceptron (also called FFN)""" - - def __init__(self, input_dim, hidden_dim, output_dim, num_layers): - super().__init__() - self.num_layers = num_layers - h = [hidden_dim] * (num_layers - 1) - self.layers = nn.ModuleList( - nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])) - - def forward(self, x): - for i, layer in enumerate(self.layers): - x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x) - return x - - -class GELU(nn.Module): - - def forward(self, x): - return 0.5 * x * (1 + torch.tanh( - math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) - - -# transformer layer class PreNorm(nn.Module): def __init__(self, dim, fn): @@ -39,22 +13,14 @@ def __init__(self, dim, fn): self.norm = nn.LayerNorm(dim) self.fn = fn - def forward(self, x, **kwargs): - return self.fn(self.norm(x), **kwargs) - - -class PreNorm_CA(nn.Module): - - def __init__(self, dim, fn): - super().__init__() - self.norm = nn.LayerNorm(dim) - self.fn = fn - - def forward(self, x, y, **kwargs): - return self.fn(self.norm(x), self.norm(y), **kwargs) + def forward(self, x, y=None, **kwargs): + if y is not None: + return self.fn(self.norm(x), self.norm(y), **kwargs) + else: + return self.fn(self.norm(x), **kwargs) -class FeedForward(nn.Module): +class FFN(nn.Module): def __init__(self, dim, hidden_dim, dropout=0.0): super().__init__() @@ -70,7 +36,7 @@ def forward(self, x): return self.net(x) -class Attention(nn.Module): +class SelfAttention(nn.Module): def __init__(self, dim, @@ -94,7 +60,7 @@ def __init__(self, if project_out else nn.Identity()) def forward(self, x): - b, n, _, h = *x.shape, self.heads + _, _, _, h = *x.shape, self.heads qkv = self.to_qkv(x).chunk(3, dim=-1) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) @@ -111,53 +77,7 @@ def forward(self, x): return self.to_out(out) -class Cross_attention(nn.Module): - - def __init__(self, - dim, - heads=8, - dim_head=64, - dropout=0.0, - out_attention=False): - super().__init__() - inner_dim = dim_head * heads - project_out = not (heads == 1 and dim_head == dim) - - self.heads = heads - self.scale = dim_head**-0.5 - self.out_attention = out_attention - - self.attend = nn.Softmax(dim=-1) - self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) - self.to_q = nn.Linear(dim, inner_dim, bias=False) 
- - self.to_out = ( - nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) - if project_out else nn.Identity()) - - def forward(self, x, y): - b, n, m, _, h = *y.shape, self.heads - q = self.to_q(x) - kv = self.to_kv(y).chunk(2, dim=-1) - q = rearrange(q, 'b n (h d) -> (b n) h 1 d', h=h) - k, v = map(lambda t: rearrange(t, 'b n m (h d) -> (b n) h m d', h=h), - kv) - - dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale - - attn = self.attend(dots) - - out = einsum('b h i j, b h j d -> b h i d', attn, v) - out = rearrange(out, '(b n) h 1 d -> b n (h d)', b=b) - - if self.out_attention: - return self.to_out(out), rearrange( - attn, '(b n) h i j -> b n h (i j)', b=b) - else: - return self.to_out(out) - - -class DeformableTransformerCrossAttention(nn.Module): +class DeformableCrossAttention(nn.Module): def __init__( self, @@ -213,103 +133,15 @@ def forward( return tgt -class Transformer(nn.Module): - - def __init__( - self, - dim, - depth=2, - heads=4, - dim_head=64, - mlp_dim=256, - dropout=0.0, - out_attention=False, - ): - super().__init__() - self.out_attention = out_attention - self.layers = nn.ModuleList([]) - self.depth = depth - - for _ in range(depth): - self.layers.append( - nn.ModuleList([ - PreNorm( - dim, - Attention( - dim, - heads=heads, - dim_head=dim_head, - dropout=dropout, - out_attention=self.out_attention, - ), - ), - PreNorm_CA( - dim, - Cross_attention( - dim, - heads=heads, - dim_head=dim_head, - dropout=dropout, - out_attention=self.out_attention, - ), - ), - PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)), - ])) - - def forward(self, - x, - pos_embedding=None, - center_pos=None, - y=None, - neighbor_pos=None): - if self.out_attention: - out_cross_attention_list = [] - if center_pos is not None and pos_embedding is not None: - center_pos_embedding = pos_embedding(center_pos) - if neighbor_pos is not None and pos_embedding is not None: - neighbor_pos_embedding = pos_embedding(neighbor_pos) - for i, (self_attn, cross_attn, ff) in enumerate(self.layers): - if self.out_attention: - if pos_embedding is not None: - x_att, self_att = self_attn(x + center_pos_embedding) - x = x_att + x - x_att, cross_att = cross_attn(x + center_pos_embedding, - y + neighbor_pos_embedding) - else: - x_att, self_att = self_attn(x) - x = x_att + x - x_att, cross_att = cross_attn(x, y) - out_cross_attention_list.append(cross_att) - else: - if pos_embedding is not None: - x_att = self_attn(x + center_pos_embedding) - x = x_att + x - x_att = cross_attn(x + center_pos_embedding, - y + neighbor_pos_embedding) - else: - x_att = self_attn(x) - x = x_att + x - x_att = cross_attn(x, y) - - x = x_att + x - x = ff(x) + x - - out_dict = {'ct_feat': x} - if self.out_attention: - out_dict.update({ - 'out_attention': - torch.stack(out_cross_attention_list, dim=2) - }) - return out_dict +class DeformableTransformer(nn.Module): + """Deformable transformer. - -class Deform_Transformer(nn.Module): - '''Deformable transformer. - Note that the difference between this implementation and the implementation - in MMDet is that the dimension of input and hidden embedding in the - multi-attention-head can be specified respectively. - - ''' + Note that the ``DeformableDetrTransformerDecoder`` in MMDet has different + interfaces in multi-head-attention which is customized here. For example, + 'embed_dims' is not a position argument in our customized multi-head-self- + attention, but is required in MMDet. Thus, we can not directly use the + ``DeformableDetrTransformerDecoder`` in MMDET. 
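In the config added earlier in this series the decoder is built from `transformer_config`; roughly, the mapping onto this constructor looks as follows (a sketch only, with `dim=256` corresponding to `ds_num_filters[-1] * 2` in the backbone):

```python
# hypothetical stand-alone construction mirroring the provided config
decoder = DeformableTransformer(
    dim=256,            # ds_num_filters[-1] * 2 = 128 * 2
    depth=2,
    heads=6,
    dim_head=64,
    dim_ffn=256,        # renamed from mlp_dim in this commit
    dropout=0.3,
    out_attention=False,
    n_points=15,
)
```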
+ """ def __init__( self, @@ -318,7 +150,7 @@ def __init__( depth=2, heads=4, dim_head=32, - mlp_dim=256, + dim_ffn=256, dropout=0.0, out_attention=False, n_points=9, @@ -335,7 +167,7 @@ def __init__( nn.ModuleList([ PreNorm( dim, - Attention( + SelfAttention( dim, heads=heads, dim_head=dim_head, @@ -343,9 +175,9 @@ def __init__( out_attention=self.out_attention, ), ), - PreNorm_CA( + PreNorm( dim, - DeformableTransformerCrossAttention( + DeformableCrossAttention( dim, dim_head, n_levels=levels, @@ -355,7 +187,7 @@ def __init__( out_sample_loc=self.out_attention, ), ), - PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)), + PreNorm(dim, FFN(dim, dim_ffn, dropout=dropout)), ])) def forward(self, x, pos_embedding, src, src_spatial_shapes, diff --git a/projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py b/projects/centerformer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py similarity index 97% rename from projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py rename to projects/centerformer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py index 103fcdd1f8..207bd41abd 100644 --- a/projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py +++ b/projects/centerformer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py @@ -41,7 +41,7 @@ encoder_paddings=((1, 1, 1), (1, 1, 1), (1, 1, [0, 1, 1]), (1, 1)), block_type='basicblock'), backbone=dict( - type='RPN_transformer_deformable', + type='DeformableDecoderRPN', layer_nums=[5, 5, 1], ds_num_filters=[256, 256, 128], num_input_features=256, @@ -55,14 +55,14 @@ depth=2, heads=6, dim_head=64, - MLP_dim=256, - DP_rate=0.3, + dim_ffn=256, + dropout_rate=0.3, out_att=False, n_points=15, ), ), bbox_head=dict( - type='CenterHeadIoU_1d', + type='CenterFormerBboxHead', in_channels=256, tasks=tasks, dataset='waymo', @@ -129,6 +129,7 @@ load_dim=6, use_dim=5, norm_intensity=True), + # Add this if using `MultiFrameDeformableDecoderRPN` # dict( # type='LoadPointsFromMultiSweeps', # sweeps_num=9, @@ -299,9 +300,6 @@ auto_scale_lr = dict(enable=False, base_batch_size=16) default_hooks = dict( - logger=dict( - type='LoggerHook', - interval=50, - ), + logger=dict(type='LoggerHook', interval=50), checkpoint=dict(type='CheckpointHook', interval=5)) custom_hooks = [dict(type='DisableObjectSampleHook', disable_after_epoch=15)] diff --git a/tools/train.py b/tools/train.py index a31de28fb2..b1b1d27b8b 100644 --- a/tools/train.py +++ b/tools/train.py @@ -98,7 +98,7 @@ def main(): f'`OptimWrapper` but got {optim_wrapper}.') cfg.optim_wrapper.type = 'AmpOptimWrapper' cfg.optim_wrapper.loss_scale = 'dynamic' - + if args.disable_tf32: import torch torch.backends.cuda.matmul.allow_tf32 = False From 2e412bd81a901f814b88e63236ab25875165ae60 Mon Sep 17 00:00:00 2001 From: JingweiZhang12 Date: Tue, 27 Dec 2022 18:18:26 +0800 Subject: [PATCH 07/18] polish code and add dosstring --- .../centerformer/centerformer_backbone.py | 39 ++++++++++------ .../centerformer/multi_scale_deform_attn.py | 24 +++++----- .../centerformer/centerformer/transformer.py | 46 +++++++++---------- ...n-attn_4xb4-cyclic-20e_waymoD5-3d-class.py | 10 ++-- 4 files changed, 65 insertions(+), 54 deletions(-) diff --git a/projects/centerformer/centerformer/centerformer_backbone.py 
b/projects/centerformer/centerformer/centerformer_backbone.py index 5bef2cd6fe..13082a0621 100644 --- a/projects/centerformer/centerformer/centerformer_backbone.py +++ b/projects/centerformer/centerformer/centerformer_backbone.py @@ -11,7 +11,7 @@ from mmdet3d.models.utils import draw_heatmap_gaussian, gaussian_radius from mmdet3d.registry import MODELS from mmdet3d.structures import center_to_corner_box2d -from .transformer import DeformableTransformer +from .transformer import DeformableTransformerDecoder class ChannelAttention(nn.Module): @@ -333,7 +333,8 @@ def get_multi_scale_feature_multiframe(self, center_pos, feats, timeframe): class DeformableDecoderRPN(BaseDecoderRPN): """The original implement of CenterFormer modules. - It fuse the backbone neck and heatmap head into one module. + It fuse the backbone, neck and heatmap head into one module. The backbone + is `SECOND` with attention and the neck is `SECONDFPN` with attention. TODO: split this module into backbone、neck and head. """ @@ -381,14 +382,14 @@ def __init__(self, self.tasks = tasks self.class_names = [t['class_names'] for t in tasks] - self.transformer_layer = DeformableTransformer( + self.transformer_decoder = DeformableTransformerDecoder( self._num_filters[-1] * 2, depth=transformer_config.depth, - heads=transformer_config.heads, - dim_head=transformer_config.dim_head, + n_heads=transformer_config.n_heads, + dim_single_head=transformer_config.dim_single_head, dim_ffn=transformer_config.dim_ffn, - dropout=transformer_config.dropout_rate, - out_attention=transformer_config.out_att, + dropout=transformer_config.dropout, + out_attention=transformer_config.out_attn, n_points=transformer_config.get('n_points', 9), ) self.pos_embedding_type = transformer_config.get( @@ -499,7 +500,7 @@ def forward(self, x, batch_data_samples): spatial_shapes.prod(1).cumsum(0)[:-1], )) - transformer_out = self.transformer_layer( + transformer_out = self.transformer_decoder( ct_feat, self.pos_embedding, src, @@ -756,6 +757,14 @@ def get_targets_single(self, @MODELS.register_module() class MultiFrameDeformableDecoderRPN(BaseDecoderRPN): + """The original implementation of CenterFormer modules. + + The difference between this module and + `DeformableDecoderRPN` is that this module uses information from multi + frames. + + TODO: split this module into backbone、neck and head. 
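A hedged sketch of how the multi-frame variant could be enabled, based on the `frame` argument of this class and the commented-out `LoadPointsFromMultiSweeps` hint in the provided config (the frame count below is illustrative; the number of feature levels inside the decoder becomes `2 + frame`):

```python
# hypothetical config fragment, not part of this patch series
model = dict(
    backbone=dict(
        type='MultiFrameDeformableDecoderRPN',
        frame=4,  # number of past frames fused by MultiFrameSpatialAttention
        # remaining keys as in the DeformableDecoderRPN config
    ))
train_pipeline_extra = [
    dict(type='LoadPointsFromMultiSweeps', sweeps_num=9),
]
```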
+ """ def __init__( self, @@ -810,15 +819,15 @@ def __init__( self.mtf_attention = MultiFrameSpatialAttention() self.time_embedding = nn.Linear(1, self._num_filters[0]) - self.transformer_layer = DeformableTransformer( + self.transformer_decoder = DeformableTransformerDecoder( self._num_filters[-1] * 2, depth=transformer_config.depth, - heads=transformer_config.heads, - levels=2 + self.frame, - dim_head=transformer_config.dim_head, + n_heads=transformer_config.n_heads, + n_levels=2 + self.frame, + dim_single_head=transformer_config.dim_single_head, dim_ffn=transformer_config.dim_ffn, - dropout=transformer_config.dropout_rate, - out_attention=transformer_config.out_att, + dropout=transformer_config.dropout, + out_attention=transformer_config.out_attn, n_points=transformer_config.get('n_points', 9), ) self.pos_embedding_type = transformer_config.get( @@ -939,7 +948,7 @@ def forward(self, x, example=None): spatial_shapes.prod(1).cumsum(0)[:-1], )) - transformer_out = self.transformer_layer( + transformer_out = self.transformer_decoder( ct_feat, self.pos_embedding, src, diff --git a/projects/centerformer/centerformer/multi_scale_deform_attn.py b/projects/centerformer/centerformer/multi_scale_deform_attn.py index ec79dc8e3c..892b26a06a 100644 --- a/projects/centerformer/centerformer/multi_scale_deform_attn.py +++ b/projects/centerformer/centerformer/multi_scale_deform_attn.py @@ -92,8 +92,10 @@ class MSDeformAttn(nn.Module): specified respectively. Args: - d_model (int, optional): input dimension. Defaults to 256. - d_head (int, optional): hidden dimension. Defaults to 64. + dim_model (int, optional): The input and output dimension in the model. + Defaults to 256. + dim_single_head (int, optional): hidden dimension in the single head. + Defaults to 64. n_levels (int, optional): number of feature levels. Defaults to 4. n_heads (int, optional): number of attention heads. Defaults to 8. 
n_points (int, optional): number of sampling points per attention head @@ -103,8 +105,8 @@ class MSDeformAttn(nn.Module): """ def __init__(self, - d_model=256, - d_head=64, + dim_model=256, + dim_single_head=64, n_levels=4, n_heads=8, n_points=4, @@ -113,20 +115,20 @@ def __init__(self, self.im2col_step = 64 - self.d_model = d_model - self.d_head = d_head + self.dim_model = dim_model + self.dim_single_head = dim_single_head self.n_levels = n_levels self.n_heads = n_heads self.n_points = n_points self.out_sample_loc = out_sample_loc - self.sampling_offsets = nn.Linear(d_model, + self.sampling_offsets = nn.Linear(dim_model, n_heads * n_levels * n_points * 2) - self.attention_weights = nn.Linear(d_model, + self.attention_weights = nn.Linear(dim_model, n_heads * n_levels * n_points) - self.value_proj = nn.Linear(d_model, d_head * n_heads) - self.output_proj = nn.Linear(d_head * n_heads, d_model) + self.value_proj = nn.Linear(dim_model, dim_single_head * n_heads) + self.output_proj = nn.Linear(dim_single_head * n_heads, dim_model) self._reset_parameters() @@ -190,7 +192,7 @@ def forward(self, value = self.value_proj(input_flatten) if input_padding_mask is not None: value = value.masked_fill(input_padding_mask[..., None], float(0)) - value = value.view(N, Len_in, self.n_heads, self.d_head) + value = value.view(N, Len_in, self.n_heads, self.dim_single_head) sampling_offsets = self.sampling_offsets(query).view( N, Len_q, self.n_heads, self.n_levels, self.n_points, 2) attention_weights = self.attention_weights(query).view( diff --git a/projects/centerformer/centerformer/transformer.py b/projects/centerformer/centerformer/transformer.py index 7353c757b7..2b41278434 100644 --- a/projects/centerformer/centerformer/transformer.py +++ b/projects/centerformer/centerformer/transformer.py @@ -40,16 +40,16 @@ class SelfAttention(nn.Module): def __init__(self, dim, - heads=8, - dim_head=64, + n_heads=8, + dim_single_head=64, dropout=0.0, out_attention=False): super().__init__() - inner_dim = dim_head * heads - project_out = not (heads == 1 and dim_head == dim) + inner_dim = dim_single_head * n_heads + project_out = not (n_heads == 1 and dim_single_head == dim) - self.heads = heads - self.scale = dim_head**-0.5 + self.n_heads = n_heads + self.scale = dim_single_head**-0.5 self.out_attention = out_attention self.attend = nn.Softmax(dim=-1) @@ -60,7 +60,7 @@ def __init__(self, if project_out else nn.Identity()) def forward(self, x): - _, _, _, h = *x.shape, self.heads + _, _, _, h = *x.shape, self.n_heads qkv = self.to_qkv(x).chunk(3, dim=-1) q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv) @@ -81,8 +81,8 @@ class DeformableCrossAttention(nn.Module): def __init__( self, - d_model=256, - d_head=64, + dim_model=256, + dim_single_head=64, dropout=0.3, n_levels=3, n_heads=6, @@ -93,8 +93,8 @@ def __init__( # cross attention self.cross_attn = MSDeformAttn( - d_model, - d_head, + dim_model, + dim_single_head, n_levels, n_heads, n_points, @@ -133,8 +133,8 @@ def forward( return tgt -class DeformableTransformer(nn.Module): - """Deformable transformer. +class DeformableTransformerDecoder(nn.Module): + """Deformable transformer decoder. Note that the ``DeformableDetrTransformerDecoder`` in MMDet has different interfaces in multi-head-attention which is customized here. 
For example, @@ -146,10 +146,10 @@ class DeformableTransformer(nn.Module): def __init__( self, dim, - levels=3, + n_levels=3, depth=2, - heads=4, - dim_head=32, + n_heads=4, + dim_single_head=32, dim_ffn=256, dropout=0.0, out_attention=False, @@ -159,7 +159,7 @@ def __init__( self.out_attention = out_attention self.layers = nn.ModuleList([]) self.depth = depth - self.levels = levels + self.n_levels = n_levels self.n_points = n_points for _ in range(depth): @@ -169,8 +169,8 @@ def __init__( dim, SelfAttention( dim, - heads=heads, - dim_head=dim_head, + n_heads=n_heads, + dim_single_head=dim_single_head, dropout=dropout, out_attention=self.out_attention, ), @@ -179,9 +179,9 @@ def __init__( dim, DeformableCrossAttention( dim, - dim_head, - n_levels=levels, - n_heads=heads, + dim_single_head, + n_levels=n_levels, + n_heads=n_heads, dropout=dropout, n_points=n_points, out_sample_loc=self.out_attention, @@ -197,7 +197,7 @@ def forward(self, x, pos_embedding, src, src_spatial_shapes, if pos_embedding is not None: center_pos_embedding = pos_embedding(center_pos) reference_points = center_pos[:, :, - None, :].repeat(1, 1, self.levels, 1) + None, :].repeat(1, 1, self.n_levels, 1) for i, (self_attn, cross_attn, ff) in enumerate(self.layers): if self.out_attention: if center_pos_embedding is not None: diff --git a/projects/centerformer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py b/projects/centerformer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py index 207bd41abd..d740fbd630 100644 --- a/projects/centerformer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py +++ b/projects/centerformer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py @@ -53,11 +53,11 @@ norm_cfg=dict(type='SyncBN', eps=1e-3, momentum=0.01), transformer_config=dict( depth=2, - heads=6, - dim_head=64, + n_heads=6, + dim_single_head=64, dim_ffn=256, - dropout_rate=0.3, - out_att=False, + dropout=0.3, + out_attn=False, n_points=15, ), ), @@ -106,7 +106,7 @@ obj_num=1000, )) -data_root = 'data/waymo/kitti_format/' +data_root = 'data/waymo_mini/kitti_format/' db_sampler = dict( data_root=data_root, info_path=data_root + 'waymo_dbinfos_train.pkl', From 69db244905f7d5a14e832fa3e301d8fa5a7e623d Mon Sep 17 00:00:00 2001 From: JingweiZhang12 Date: Tue, 27 Dec 2022 20:54:01 +0800 Subject: [PATCH 08/18] add ut for disable_object_sample_hook --- mmdet3d/apis/inference.py | 4 +--- mmdet3d/engine/hooks/disable_object_sample_hook.py | 7 +++---- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/mmdet3d/apis/inference.py b/mmdet3d/apis/inference.py index f9726e1511..ee6bf16879 100644 --- a/mmdet3d/apis/inference.py +++ b/mmdet3d/apis/inference.py @@ -67,12 +67,10 @@ def init_model(config: Union[str, Path, Config], if checkpoint is not None: checkpoint = load_checkpoint(model, checkpoint, map_location='cpu') - if 'meta' in checkpoint: - dataset_meta = checkpoint['meta'].get('dataset_meta', None) # save the dataset_meta in the model for convenience if 'dataset_meta' in checkpoint.get('meta', {}): # mmdet3d 1.x - model.dataset_meta = dataset_meta + model.dataset_meta = checkpoint['meta']['dataset_meta'] elif 'CLASSES' in checkpoint.get('meta', {}): # < mmdet3d 1.x classes = checkpoint['meta']['CLASSES'] diff --git a/mmdet3d/engine/hooks/disable_object_sample_hook.py b/mmdet3d/engine/hooks/disable_object_sample_hook.py index 3a8182c45e..d1f3c2a09d 100644 --- 
a/mmdet3d/engine/hooks/disable_object_sample_hook.py +++ b/mmdet3d/engine/hooks/disable_object_sample_hook.py @@ -12,10 +12,9 @@ class DisableObjectSampleHook(Hook): """The hook of disabling augmentations during training. Args: - num_last_epochs (int): The number of latter epochs in the end of the - training to close the data augmentation. Default: 15. - skip_type_keys (list[str], optional): Sequence of type string to be - skipped in the data pipeline. Default: ('ObjectSample') + disable_after_epoch (int): The number of epochs after which + the ``ObjectSample`` will be closed in the training. + Defaults to 15. """ def __init__(self, disable_after_epoch: int = 15): From 283eab55e848243783fe5d866865e866e8180f13 Mon Sep 17 00:00:00 2001 From: JingweiZhang12 Date: Tue, 27 Dec 2022 20:56:32 +0800 Subject: [PATCH 09/18] modify data_root --- ..._second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/centerformer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py b/projects/centerformer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py index d740fbd630..70bed7c603 100644 --- a/projects/centerformer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py +++ b/projects/centerformer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py @@ -106,7 +106,7 @@ obj_num=1000, )) -data_root = 'data/waymo_mini/kitti_format/' +data_root = 'data/waymo/kitti_format/' db_sampler = dict( data_root=data_root, info_path=data_root + 'waymo_dbinfos_train.pkl', From 7a7b4f0bbfc89279fe8ef6002a7412510b4e3aa7 Mon Sep 17 00:00:00 2001 From: JingweiZhang12 Date: Wed, 28 Dec 2022 11:03:04 +0800 Subject: [PATCH 10/18] add ut --- .../test_disable_object_sample_hook.py | 75 +++++++++++++++++++ 1 file changed, 75 insertions(+) create mode 100644 tests/test_engine/test_hooks/test_disable_object_sample_hook.py diff --git a/tests/test_engine/test_hooks/test_disable_object_sample_hook.py b/tests/test_engine/test_hooks/test_disable_object_sample_hook.py new file mode 100644 index 0000000000..fcc1e3c8d8 --- /dev/null +++ b/tests/test_engine/test_hooks/test_disable_object_sample_hook.py @@ -0,0 +1,75 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+from unittest import TestCase +from unittest.mock import Mock + +from mmdet3d.datasets.transforms import ObjectSample +from mmdet3d.engine.hooks import DisableObjectSampleHook + + +class TestDisableObjectSampleHook(TestCase): + + runner = Mock() + runner.train_dataloader = Mock() + runner.train_dataloader.dataset = Mock() + runner.train_dataloader.dataset.pipeline = Mock() + runner.train_dataloader._DataLoader__initialized = True + runner.train_dataloader.dataset.pipeline.transforms = [ + ObjectSample( + db_sampler=dict( + data_root='tests/data/waymo/kitti_format', + info_path= # noqa + 'tests/data/waymo/kitti_format/waymo_dbinfos_train.pkl', + rate=1.0, + prepare=dict( + filter_by_difficulty=[-1], + filter_by_min_points=dict(Car=5)), + classes=['Car'], + sample_groups=dict(Car=15), + )) + ] + + def test_is_model_wrapper_and_persistent_workers_on(self): + self.runner.train_dataloader.dataset.pipeline.transforms[ + 0].disabled = False + self.runner.train_dataloader.persistent_workers = True + hook = DisableObjectSampleHook(disable_after_epoch=15) + self.runner.epoch = 14 + hook.before_train_epoch(self.runner) + self.assertFalse(self.runner.train_dataloader.dataset.pipeline. + transforms[0].disabled) # noqa: E501 + + self.runner.epoch = 15 + hook.before_train_epoch(self.runner) + self.assertTrue(self.runner.train_dataloader.dataset.pipeline. + transforms[0].disabled) # noqa: E501 + self.assertTrue(hook._restart_dataloader) + self.assertFalse(self.runner.train_dataloader._DataLoader__initialized) + + self.runner.epoch = 16 + hook.before_train_epoch(self.runner) + self.assertTrue(self.runner.train_dataloader._DataLoader__initialized) + self.assertTrue(self.runner.train_dataloader.dataset.pipeline. + transforms[0].disabled) # noqa: E501 + + def test_not_model_wrapper_and_persistent_workers_off(self): + self.runner.train_dataloader.dataset.pipeline.transforms[ + 0].disabled = False + self.runner.train_dataloader.persistent_workers = False + hook = DisableObjectSampleHook(disable_after_epoch=15) + self.runner.epoch = 14 + hook.before_train_epoch(self.runner) + self.assertFalse(self.runner.train_dataloader.dataset.pipeline. + transforms[0].disabled) # noqa: E501 + + self.runner.epoch = 15 + hook.before_train_epoch(self.runner) + self.assertTrue(self.runner.train_dataloader.dataset.pipeline. + transforms[0].disabled) # noqa: E501 + self.assertFalse(hook._restart_dataloader) + self.assertTrue(self.runner.train_dataloader._DataLoader__initialized) + + self.runner.epoch = 16 + hook.before_train_epoch(self.runner) + self.assertTrue(self.runner.train_dataloader._DataLoader__initialized) + self.assertTrue(self.runner.train_dataloader.dataset.pipeline. 
+ transforms[0].disabled) # noqa: E501 From 80380111f7a952c446364ded86e9bafc370e7976 Mon Sep 17 00:00:00 2001 From: JingweiZhang12 Date: Wed, 28 Dec 2022 16:47:26 +0800 Subject: [PATCH 11/18] update readme --- projects/centerformer/README.md | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/projects/centerformer/README.md b/projects/centerformer/README.md index fc862da7fc..0ef6cd064d 100644 --- a/projects/centerformer/README.md +++ b/projects/centerformer/README.md @@ -79,11 +79,13 @@ python tools/train.py projects/centerformer/configs/centerformer_voxel01_second- ## Results and models -### CenterFormer +### Waymo -| Backbone | Voxel type (voxel size) | Multi-Class NMS | Mem (GB) | Inf time (fps) | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download | -| :-------------------------------------------------------------------------------------: | :---------------------: | :-------------: | :------: | :------------: | :----: | :-----: | :----: | :---------: | :----------------------: | -| [SECFPN_WithAtten](./centerpoint_01voxel_second_secfpn_circlenms_4x8_cyclic_20e_nus.py) | voxel (0.1) | ✓ | 5.2 | | | | | | [model](<>) \| [log](<>) | +| Backbone | Load Interval | Voxel type (voxel size) | Multi-Class NMS | Multi-frames | Mem (GB) | Inf time (fps) | mAP@L1 | mAPH@L1 | mAP@L2 | **mAPH@L2** | Download | +| :----------------------------------------------------------------------------------------------------------------: | :-----------: | :---------------------: | :-------------: | :----------: | :------: | :------------: | :----: | :-----: | :----: | :---------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| [SECFPN_WithAttention](./configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py) | 5 | voxel (0.1) | ✓ | × | 14.8 | | 72.2 | 69.5 | 65.9 | 63.3 | [log](https://download.openmmlab.com/mmdetection3d/v1.1.0_models/centerformer/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class_20221227_205613-70c9ad37.json) | + +**Note** that `SECFPN_WithAttention` denotes both SECOND and SECONDFPN with ChannelAttention and SpatialAttention. 
## Citation From 7384c0c5f306c590a0046bd5e7a8073e1b003709 Mon Sep 17 00:00:00 2001 From: JingweiZhang12 Date: Wed, 28 Dec 2022 17:13:15 +0800 Subject: [PATCH 12/18] polish code --- projects/centerformer/README.md | 2 +- .../centerformer/centerformer/centerformer_backbone.py | 2 ++ .../centerformer/centerformer/centerformer_head.py | 2 -- projects/centerformer/centerformer/losses.py | 10 ++++++---- .../centerformer/multi_scale_deform_attn.py | 2 ++ projects/centerformer/centerformer/transformer.py | 2 ++ 6 files changed, 13 insertions(+), 7 deletions(-) diff --git a/projects/centerformer/README.md b/projects/centerformer/README.md index 0ef6cd064d..8f958daf87 100644 --- a/projects/centerformer/README.md +++ b/projects/centerformer/README.md @@ -71,7 +71,7 @@ python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=${N ### Testing commands -In MMDetection's root directory, run the following command to test the model: +In MMDetection3D's root directory, run the following command to test the model: ```bash python tools/train.py projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py ${CHECKPOINT_PATH} diff --git a/projects/centerformer/centerformer/centerformer_backbone.py b/projects/centerformer/centerformer/centerformer_backbone.py index 13082a0621..e29d1228ef 100644 --- a/projects/centerformer/centerformer/centerformer_backbone.py +++ b/projects/centerformer/centerformer/centerformer_backbone.py @@ -1,3 +1,5 @@ +# modify from https://github.com/TuSimple/centerformer/blob/master/det3d/models/necks/rpn_transformer.py # noqa + from typing import List, Tuple import numpy as np diff --git a/projects/centerformer/centerformer/centerformer_head.py b/projects/centerformer/centerformer/centerformer_head.py index 051e298df1..f1e5cbe937 100644 --- a/projects/centerformer/centerformer/centerformer_head.py +++ b/projects/centerformer/centerformer/centerformer_head.py @@ -289,7 +289,6 @@ def loss(self, preds_dicts, example, **kwargs): return losses - @torch.no_grad() def predict(self, preds_dicts, batch_input_metas, **kwargs): """decode, nms, then return the detection result. 
@@ -410,7 +409,6 @@ def predict(self, preds_dicts, batch_input_metas, **kwargs): return ret_list - @torch.no_grad() def post_processing( self, img_metas, diff --git a/projects/centerformer/centerformer/losses.py b/projects/centerformer/centerformer/losses.py index daa93c6a2d..e59dc8f984 100644 --- a/projects/centerformer/centerformer/losses.py +++ b/projects/centerformer/centerformer/losses.py @@ -1,3 +1,5 @@ +# modify from https://github.com/TuSimple/centerformer/blob/master/det3d/models/losses/centernet_loss.py # noqa + import torch from torch import nn @@ -35,10 +37,10 @@ def __init__(self, focal_factor=2): def forward(self, out, target, ind, mask, cat): ''' - Arguments: - out, target: B x C x H x W - ind, mask: B x M - cat (category id for peaks): B x M + Args: + out, target: B x C x H x W + ind, mask: B x M + cat (category id for peaks): B x M ''' mask = mask.float() gt = torch.pow(1 - target, 4) diff --git a/projects/centerformer/centerformer/multi_scale_deform_attn.py b/projects/centerformer/centerformer/multi_scale_deform_attn.py index 892b26a06a..6c39af9cd8 100644 --- a/projects/centerformer/centerformer/multi_scale_deform_attn.py +++ b/projects/centerformer/centerformer/multi_scale_deform_attn.py @@ -1,3 +1,5 @@ +# modify from https://github.com/TuSimple/centerformer/blob/master/det3d/models/ops/modules/ms_deform_attn.py # noqa + import math from typing import Optional diff --git a/projects/centerformer/centerformer/transformer.py b/projects/centerformer/centerformer/transformer.py index 2b41278434..88b8ff2a55 100644 --- a/projects/centerformer/centerformer/transformer.py +++ b/projects/centerformer/centerformer/transformer.py @@ -1,3 +1,5 @@ +# modify from https://github.com/TuSimple/centerformer/blob/master/det3d/models/utils/transformer.py # noqa + import torch from einops import rearrange from mmcv.cnn.bricks.activation import GELU From ca1ee7688888277353e9a906157a554d592cbb31 Mon Sep 17 00:00:00 2001 From: JingweiZhang12 Date: Wed, 28 Dec 2022 17:21:47 +0800 Subject: [PATCH 13/18] fix docstring --- projects/centerformer/centerformer/bbox_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/projects/centerformer/centerformer/bbox_ops.py b/projects/centerformer/centerformer/bbox_ops.py index d850eaf8ee..8a7180d1cb 100644 --- a/projects/centerformer/centerformer/bbox_ops.py +++ b/projects/centerformer/centerformer/bbox_ops.py @@ -10,8 +10,8 @@ def nms_iou3d(boxes, scores, thresh, pre_maxsize=None, post_max_size=None): `post_max_size` before and after NMS respectively. Args: - boxes (Tensor): Input boxes with the shape of [N, 5] - ([x1, y1, x2, y2, ry]). + boxes (Tensor): Input boxes with the shape of [N, 7] + ([cx, cy, cz, l, w, h, theta]). scores (Tensor): Scores of boxes with the shape of [N]. thresh (float): Overlap threshold of NMS. pre_max_size (int, optional): Max size of boxes before NMS. 
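Editor's note: the corrected docstring above pins down the expected box layout for `nms_iou3d`. Below is a minimal usage sketch, not code from this series. It assumes the module path in use at this point (`projects/centerformer/...`; the next patch renames it to `projects/CenterFormer/...`), that the helper returns the indices of the kept boxes like the other BEV NMS helpers in MMDetection3D, and that a CUDA device is available. The box values and thresholds are illustrative only.

```python
import torch

# Path before the PATCH 14 rename; afterwards it would be
# projects.CenterFormer.centerformer.bbox_ops (assumption based on this series).
from projects.centerformer.centerformer.bbox_ops import nms_iou3d

# 200 candidate boxes in (cx, cy, cz, l, w, h, theta) layout with scores,
# matching the corrected docstring (random values, for illustration only).
boxes = torch.rand(200, 7, device='cuda')
scores = torch.rand(200, device='cuda')

# Keep at most 500 boxes before NMS and 83 afterwards, with an IoU3D
# threshold of 0.55 (hypothetical values, not the ones used in the config).
keep = nms_iou3d(boxes, scores, thresh=0.55, pre_maxsize=500, post_max_size=83)
kept_boxes = boxes[keep]
kept_scores = scores[keep]
```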
From 91ca015607872ceafec83161916eb46d2dd1cca0 Mon Sep 17 00:00:00 2001 From: JingweiZhang12 Date: Tue, 3 Jan 2023 19:25:51 +0800 Subject: [PATCH 14/18] resolve comments --- projects/{centerformer => CenterFormer}/README.md | 0 .../centerformer/__init__.py | 0 .../centerformer/bbox_ops.py | 4 ++-- .../centerformer/centerformer.py | 14 -------------- .../centerformer/centerformer_backbone.py | 0 .../centerformer/centerformer_head.py | 0 .../centerformer/losses.py | 0 .../centerformer/multi_scale_deform_attn.py | 0 .../centerformer/transformer.py | 0 ...secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py | 0 10 files changed, 2 insertions(+), 16 deletions(-) rename projects/{centerformer => CenterFormer}/README.md (100%) rename projects/{centerformer => CenterFormer}/centerformer/__init__.py (100%) rename projects/{centerformer => CenterFormer}/centerformer/bbox_ops.py (94%) rename projects/{centerformer => CenterFormer}/centerformer/centerformer.py (94%) rename projects/{centerformer => CenterFormer}/centerformer/centerformer_backbone.py (100%) rename projects/{centerformer => CenterFormer}/centerformer/centerformer_head.py (100%) rename projects/{centerformer => CenterFormer}/centerformer/losses.py (100%) rename projects/{centerformer => CenterFormer}/centerformer/multi_scale_deform_attn.py (100%) rename projects/{centerformer => CenterFormer}/centerformer/transformer.py (100%) rename projects/{centerformer => CenterFormer}/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py (100%) diff --git a/projects/centerformer/README.md b/projects/CenterFormer/README.md similarity index 100% rename from projects/centerformer/README.md rename to projects/CenterFormer/README.md diff --git a/projects/centerformer/centerformer/__init__.py b/projects/CenterFormer/centerformer/__init__.py similarity index 100% rename from projects/centerformer/centerformer/__init__.py rename to projects/CenterFormer/centerformer/__init__.py diff --git a/projects/centerformer/centerformer/bbox_ops.py b/projects/CenterFormer/centerformer/bbox_ops.py similarity index 94% rename from projects/centerformer/centerformer/bbox_ops.py rename to projects/CenterFormer/centerformer/bbox_ops.py index 8a7180d1cb..2b13a2ba13 100644 --- a/projects/centerformer/centerformer/bbox_ops.py +++ b/projects/CenterFormer/centerformer/bbox_ops.py @@ -5,8 +5,8 @@ def nms_iou3d(boxes, scores, thresh, pre_maxsize=None, post_max_size=None): - """NMS function GPU implementation (using IoU3D) The difference of this - implementation with nms3d in MMCV is that we add `pre_maxsize` and + """NMS function GPU implementation (using IoU3D) The difference between + this implementation and nms3d in MMCV is that we add `pre_maxsize` and `post_max_size` before and after NMS respectively. 
Args: diff --git a/projects/centerformer/centerformer/centerformer.py b/projects/CenterFormer/centerformer/centerformer.py similarity index 94% rename from projects/centerformer/centerformer/centerformer.py rename to projects/CenterFormer/centerformer/centerformer.py index 5107b92993..bfffdf26d7 100644 --- a/projects/centerformer/centerformer/centerformer.py +++ b/projects/CenterFormer/centerformer/centerformer.py @@ -82,17 +82,6 @@ def with_backbone(self): """bool: Whether the detector has a 3D backbone.""" return hasattr(self, 'backbone') and self.backbone is not None - @property - def with_fusion(self): - """bool: Whether the detector has a fusion layer.""" - return hasattr(self, - 'pts_fusion_layer') and self.fusion_layer is not None - - @property - def with_neck(self): - """bool: Whether the detector has a neck in 3D detector branch.""" - return hasattr(self, 'neck') and self.neck is not None - @property def with_voxel_encoder(self): """bool: Whether the detector has a voxel encoder.""" @@ -105,9 +94,6 @@ def with_middle_encoder(self): return hasattr(self, 'middle_encoder') and self.middle_encoder is not None - def _forward(self): - pass - def extract_feat(self, batch_inputs_dict: dict, batch_input_metas: List[dict]) -> tuple: """Extract features from images and points. diff --git a/projects/centerformer/centerformer/centerformer_backbone.py b/projects/CenterFormer/centerformer/centerformer_backbone.py similarity index 100% rename from projects/centerformer/centerformer/centerformer_backbone.py rename to projects/CenterFormer/centerformer/centerformer_backbone.py diff --git a/projects/centerformer/centerformer/centerformer_head.py b/projects/CenterFormer/centerformer/centerformer_head.py similarity index 100% rename from projects/centerformer/centerformer/centerformer_head.py rename to projects/CenterFormer/centerformer/centerformer_head.py diff --git a/projects/centerformer/centerformer/losses.py b/projects/CenterFormer/centerformer/losses.py similarity index 100% rename from projects/centerformer/centerformer/losses.py rename to projects/CenterFormer/centerformer/losses.py diff --git a/projects/centerformer/centerformer/multi_scale_deform_attn.py b/projects/CenterFormer/centerformer/multi_scale_deform_attn.py similarity index 100% rename from projects/centerformer/centerformer/multi_scale_deform_attn.py rename to projects/CenterFormer/centerformer/multi_scale_deform_attn.py diff --git a/projects/centerformer/centerformer/transformer.py b/projects/CenterFormer/centerformer/transformer.py similarity index 100% rename from projects/centerformer/centerformer/transformer.py rename to projects/CenterFormer/centerformer/transformer.py diff --git a/projects/centerformer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py b/projects/CenterFormer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py similarity index 100% rename from projects/centerformer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py rename to projects/CenterFormer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py From 5c959ff3b8672dfac639e32f4c0b500378fa3592 Mon Sep 17 00:00:00 2001 From: JingweiZhang12 Date: Tue, 3 Jan 2023 19:27:48 +0800 Subject: [PATCH 15/18] modify project names --- projects/CenterFormer/README.md | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/projects/CenterFormer/README.md b/projects/CenterFormer/README.md index 
8f958daf87..697ac3ce02 100644 --- a/projects/CenterFormer/README.md +++ b/projects/CenterFormer/README.md @@ -60,13 +60,13 @@ We follow the below style to name config files. Contributors are advised to foll In MMDetection3D's root directory, run the following command to train the model: ```bash -python tools/train.py projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py +python tools/train.py projects/CenterFormer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py ``` For multi-gpu training, run: ```bash -python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=${NUM_GPUS} --master_port=29506 --master_addr="127.0.0.1" tools/train.py projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py +python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=${NUM_GPUS} --master_port=29506 --master_addr="127.0.0.1" tools/train.py projects/CenterFormer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py ``` ### Testing commands @@ -74,7 +74,7 @@ python -m torch.distributed.launch --nnodes=1 --node_rank=0 --nproc_per_node=${N In MMDetection3D's root directory, run the following command to test the model: ```bash -python tools/train.py projects/centerformer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py ${CHECKPOINT_PATH} +python tools/train.py projects/CenterFormer/configs/centerformer_voxel01_second-atten_secfpn-atten_4xb4-cyclic-20e_waymoD5-3d-class.py ${CHECKPOINT_PATH} ``` ## Results and models From 1e5360c6ee9e54a3a928460636ada7732088da71 Mon Sep 17 00:00:00 2001 From: JingweiZhang12 Date: Tue, 3 Jan 2023 19:34:19 +0800 Subject: [PATCH 16/18] modify project names and add _forward --- projects/CenterFormer/centerformer/centerformer.py | 3 +++ ...second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/projects/CenterFormer/centerformer/centerformer.py b/projects/CenterFormer/centerformer/centerformer.py index bfffdf26d7..6b8b64dcd9 100644 --- a/projects/CenterFormer/centerformer/centerformer.py +++ b/projects/CenterFormer/centerformer/centerformer.py @@ -94,6 +94,9 @@ def with_middle_encoder(self): return hasattr(self, 'middle_encoder') and self.middle_encoder is not None + def _forward(self): + pass + def extract_feat(self, batch_inputs_dict: dict, batch_input_metas: List[dict]) -> tuple: """Extract features from images and points. 
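Editor's note: the commit message does not say why the `_forward` stub comes back after being dropped in PATCH 14. The likely reason, an assumption about the base-class contract rather than something stated in this series, is that the MMDet/MMDet3D base detector declares `_forward` as an abstract method for the raw-tensor forward mode, so a concrete detector cannot be instantiated without at least a stub. A minimal sketch of that contract with hypothetical class names:

```python
from abc import ABCMeta, abstractmethod


class BaseDetectorSketch(metaclass=ABCMeta):
    """Stand-in for the real base detector (hypothetical, for illustration)."""

    @abstractmethod
    def _forward(self, *args, **kwargs):
        """Raw-tensor forward ('tensor' mode); must be overridden."""


class CenterFormerSketch(BaseDetectorSketch):

    def _forward(self):
        # The detector only implements the loss/predict paths, so a no-op
        # stub is enough to make the class instantiable.
        pass


CenterFormerSketch()  # works; without the stub this would raise TypeError
```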
diff --git a/projects/CenterFormer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py b/projects/CenterFormer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py index 70bed7c603..df19f3c6a6 100644 --- a/projects/CenterFormer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py +++ b/projects/CenterFormer/configs/centerformer_voxel01_second-attn_secfpn-attn_4xb4-cyclic-20e_waymoD5-3d-class.py @@ -1,6 +1,6 @@ _base_ = ['mmdet3d::_base_/default_runtime.py'] custom_imports = dict( - imports=['projects.centerformer.centerformer'], allow_failed_imports=False) + imports=['projects.CenterFormer.centerformer'], allow_failed_imports=False) # model settings # Voxel size for voxel encoder From 66da32b1cf57219da7c25bcd5529bdbc1cb15007 Mon Sep 17 00:00:00 2001 From: JingweiZhang12 Date: Tue, 3 Jan 2023 19:36:32 +0800 Subject: [PATCH 17/18] fix docstring --- projects/CenterFormer/centerformer/bbox_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/CenterFormer/centerformer/bbox_ops.py b/projects/CenterFormer/centerformer/bbox_ops.py index 2b13a2ba13..dca5d767e6 100644 --- a/projects/CenterFormer/centerformer/bbox_ops.py +++ b/projects/CenterFormer/centerformer/bbox_ops.py @@ -5,7 +5,7 @@ def nms_iou3d(boxes, scores, thresh, pre_maxsize=None, post_max_size=None): - """NMS function GPU implementation (using IoU3D) The difference between + """NMS function GPU implementation (using IoU3D). The difference between this implementation and nms3d in MMCV is that we add `pre_maxsize` and `post_max_size` before and after NMS respectively. From 27df975dbafe429973a35c69cab9b564c6465a2b Mon Sep 17 00:00:00 2001 From: JingweiZhang12 Date: Thu, 5 Jan 2023 09:09:27 +0800 Subject: [PATCH 18/18] remove disable_tf32 --- tools/train.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tools/train.py b/tools/train.py index b1b1d27b8b..7c65903c24 100644 --- a/tools/train.py +++ b/tools/train.py @@ -21,11 +21,6 @@ def parse_args(): action='store_true', default=False, help='enable automatic-mixed-precision training') - parser.add_argument( - '--disable-tf32', - action='store_true', - default=False, - help='disable TF32 on A100 GPUs') parser.add_argument( '--auto-scale-lr', action='store_true', @@ -99,12 +94,6 @@ def main(): cfg.optim_wrapper.type = 'AmpOptimWrapper' cfg.optim_wrapper.loss_scale = 'dynamic' - if args.disable_tf32: - import torch - torch.backends.cuda.matmul.allow_tf32 = False - torch.backends.cudnn.allow_tf32 = False - print_log('Disable TF32 on A100 GPUs', logger='current') - # enable automatically scaling LR if args.auto_scale_lr: if 'auto_scale_lr' in cfg and \
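Editor's note: with the `--disable-tf32` flag removed in the final patch, users who still want to turn TF32 off on Ampere GPUs can set the same PyTorch flags the deleted branch used from their own entry script. The two attribute assignments below are exactly the ones removed above; where to place them is up to the user.

```python
import torch

# Equivalent to the removed `--disable-tf32` branch: force full-precision
# FP32 matmuls and convolutions instead of TF32 on Ampere (A100) GPUs.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
```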
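Editor's note: the `DisableObjectSampleHook` added earlier in this series is registered through the HOOKS registry, so enabling it is a config-level change. A minimal sketch assuming MMEngine's standard `custom_hooks` mechanism; the snippet is illustrative and the CenterFormer config in this PR may wire it up differently.

```python
# In a training config: disable the GT-paste augmentation (ObjectSample)
# for the tail end of training, mirroring the hook's default of 15 epochs.
custom_hooks = [
    dict(type='DisableObjectSampleHook', disable_after_epoch=15)
]
```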