Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhance] GroupFree3d inherits BaseModule From MMCV #704

Merged
merged 1 commit into from
Jul 21, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions mmdet3d/models/dense_heads/groupfree3d_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import numpy as np
import torch
from mmcv import ConfigDict
from mmcv.cnn import ConvModule
from mmcv.cnn import ConvModule, xavier_init
from mmcv.cnn.bricks.transformer import (build_positional_encoding,
build_transformer_layer)
from mmcv.runner import force_fp32
from mmcv.runner import BaseModule, force_fp32
from torch import nn as nn
from torch.nn import functional as F

Expand All @@ -19,7 +19,7 @@
EPS = 1e-6


class PointsObjClsModule(nn.Module):
class PointsObjClsModule(BaseModule):
"""object candidate point prediction from seed point features.

Args:
Expand All @@ -39,8 +39,9 @@ def __init__(self,
num_convs=3,
conv_cfg=dict(type='Conv1d'),
norm_cfg=dict(type='BN1d'),
act_cfg=dict(type='ReLU')):
super().__init__()
act_cfg=dict(type='ReLU'),
init_cfg=None):
super().__init__(init_cfg=init_cfg)
conv_channels = [in_channel for _ in range(num_convs - 1)]
conv_channels.append(1)

Expand Down Expand Up @@ -104,7 +105,7 @@ def forward(self, xyz, features, sample_inds):


@HEADS.register_module()
class GroupFree3DHead(nn.Module):
class GroupFree3DHead(BaseModule):
r"""Bbox head of `Group-Free 3D <https://arxiv.org/abs/2104.00678>`_.

Args:
Expand Down Expand Up @@ -162,8 +163,9 @@ def __init__(self,
size_class_loss=None,
size_res_loss=None,
size_reg_loss=None,
semantic_loss=None):
super(GroupFree3DHead, self).__init__()
semantic_loss=None,
init_cfg=None):
super(GroupFree3DHead, self).__init__(init_cfg=init_cfg)
self.num_classes = num_classes
self.train_cfg = train_cfg
self.test_cfg = test_cfg
Expand Down Expand Up @@ -251,15 +253,13 @@ def init_weights(self):
# initialize transformer
for m in self.decoder_layers.parameters():
if m.dim() > 1:
nn.init.xavier_uniform_(m)

xavier_init(m, distribution='uniform')
for m in self.decoder_self_posembeds.parameters():
if m.dim() > 1:
nn.init.xavier_uniform_(m)

xavier_init(m, distribution='uniform')
for m in self.decoder_cross_posembeds.parameters():
if m.dim() > 1:
nn.init.xavier_uniform_(m)
xavier_init(m, distribution='uniform')

def _get_cls_out_channels(self):
"""Return the channel number of classification outputs."""
Expand Down