Add type hint for middle_encoder and voxel_encoder (#2556)
* 2023/05/26 add type hint

* 2023/05/26 modify ugly typehint
A-new-b authored May 29, 2023
1 parent 8e634dd commit fa724b1
Showing 7 changed files with 199 additions and 139 deletions.
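For readers who want the pattern at a glance, here is a minimal sketch of the annotation style this commit applies to the encoder forward methods. The function below is illustrative only and does not appear in the diff:

    from typing import Optional
    import torch
    from torch import Tensor

    def scatter(voxel_features: Tensor, coors: Tensor,
                batch_size: Optional[int] = None) -> Tensor:
        # Tensor inputs, an optional int, and a Tensor return value.
        return voxel_features

    print(scatter(torch.rand(2, 3), torch.zeros(2, 4)).shape)  # torch.Size([2, 3])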
16 changes: 11 additions & 5 deletions mmdet3d/models/middle_encoders/pillar_scatter.py
@@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
from torch import nn
from torch import Tensor, nn

from mmdet3d.registry import MODELS

@@ -16,14 +18,17 @@ class PointPillarsScatter(nn.Module):
output_shape (list[int]): Required output shape of features.
"""

def __init__(self, in_channels, output_shape):
def __init__(self, in_channels: int, output_shape: List[int]):
super().__init__()
self.output_shape = output_shape
self.ny = output_shape[0]
self.nx = output_shape[1]
self.in_channels = in_channels

def forward(self, voxel_features, coors, batch_size=None):
def forward(self,
voxel_features: Tensor,
coors: Tensor,
batch_size: int = None) -> Tensor:
"""Foraward function to scatter features."""
# TODO: rewrite the function in a batch manner
# no need to deal with different batch cases
@@ -32,7 +37,7 @@ def forward(self, voxel_features, coors, batch_size=None):
else:
return self.forward_single(voxel_features, coors)

def forward_single(self, voxel_features, coors):
def forward_single(self, voxel_features: Tensor, coors: Tensor) -> Tensor:
"""Scatter features of single sample.
Args:
@@ -56,7 +61,8 @@ def forward_single(self, voxel_features, coors):
canvas = canvas.view(1, self.in_channels, self.ny, self.nx)
return canvas

def forward_batch(self, voxel_features, coors, batch_size):
def forward_batch(self, voxel_features: Tensor, coors: Tensor,
batch_size: int) -> Tensor:
"""Scatter features of single sample.
Args:
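For context on what the newly annotated forward_single computes, here is a self-contained plain-torch sketch of the pillar-scatter idea. The helper name, the shapes, and the (batch_idx, z, y, x) column layout are assumptions for illustration, not code taken from this diff:

    import torch
    from torch import Tensor

    def scatter_to_canvas(voxel_features: Tensor, coors: Tensor,
                          in_channels: int, ny: int, nx: int) -> Tensor:
        # voxel_features: (N, C) pillar features; coors: (N, 4) assumed to be
        # (batch_idx, z, y, x). Both layout and shapes are illustrative.
        canvas = voxel_features.new_zeros(in_channels, nx * ny)
        indices = coors[:, 2] * nx + coors[:, 3]        # flatten (y, x) into one index
        canvas[:, indices.long()] = voxel_features.t()  # scatter pillar features
        return canvas.view(1, in_channels, ny, nx)

    feats = torch.rand(5, 64)                      # 5 pillars, 64 channels
    coors = torch.zeros(5, 4, dtype=torch.long)
    coors[:, 2] = torch.arange(5)                  # y indices of the 5 pillars
    out = scatter_to_canvas(feats, coors, 64, ny=496, nx=432)
    print(out.shape)                               # torch.Size([1, 64, 496, 432])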
82 changes: 48 additions & 34 deletions mmdet3d/models/middle_encoders/sparse_encoder.py
@@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Tuple
from typing import Dict, List, Optional, Tuple, Union

import torch
from mmcv.ops import points_in_boxes_all, three_interpolate, three_nn
@@ -18,6 +18,8 @@
else:
from mmcv.ops import SparseConvTensor, SparseSequential

TwoTupleIntType = Tuple[Tuple[int]]


@MODELS.register_module()
class SparseEncoder(nn.Module):
@@ -26,7 +28,7 @@ class SparseEncoder(nn.Module):
Args:
in_channels (int): The number of input channels.
sparse_shape (list[int]): The sparse shape of input tensor.
order (list[str], optional): Order of conv module.
order (tuple[str], optional): Order of conv module.
Defaults to ('conv', 'norm', 'act').
norm_cfg (dict, optional): Config of normalization layer. Defaults to
dict(type='BN1d', eps=1e-3, momentum=0.01).
@@ -46,19 +48,24 @@
Default to False.
"""

def __init__(self,
in_channels,
sparse_shape,
order=('conv', 'norm', 'act'),
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
base_channels=16,
output_channels=128,
encoder_channels=((16, ), (32, 32, 32), (64, 64, 64), (64, 64,
64)),
encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
1)),
block_type='conv_module',
return_middle_feats=False):
def __init__(
self,
in_channels: int,
sparse_shape: List[int],
order: Optional[Tuple[str]] = ('conv', 'norm', 'act'),
norm_cfg: Optional[dict] = dict(
type='BN1d', eps=1e-3, momentum=0.01),
base_channels: Optional[int] = 16,
output_channels: Optional[int] = 128,
encoder_channels: Optional[TwoTupleIntType] = ((16, ), (32, 32,
32),
(64, 64,
64), (64, 64, 64)),
encoder_paddings: Optional[TwoTupleIntType] = ((1, ), (1, 1, 1),
(1, 1, 1),
((0, 1, 1), 1, 1)),
block_type: Optional[str] = 'conv_module',
return_middle_feats: Optional[bool] = False):
super().__init__()
assert block_type in ['conv_module', 'basicblock']
self.sparse_shape = sparse_shape
@@ -112,7 +119,8 @@ def __init__(self,
conv_type='SparseConv3d')

@amp.autocast(enabled=False)
def forward(self, voxel_features, coors, batch_size):
def forward(self, voxel_features: Tensor, coors: Tensor,
batch_size: int) -> Union[Tensor, Tuple[Tensor, list]]:
"""Forward of SparseEncoder.
Args:
@@ -154,12 +162,14 @@ def forward(self, voxel_features, coors, batch_size):
else:
return spatial_features

def make_encoder_layers(self,
make_block,
norm_cfg,
in_channels,
block_type='conv_module',
conv_cfg=dict(type='SubMConv3d')):
def make_encoder_layers(
self,
make_block: nn.Module,
norm_cfg: Dict,
in_channels: int,
block_type: Optional[str] = 'conv_module',
conv_cfg: Optional[dict] = dict(type='SubMConv3d')
) -> int:
"""make encoder layers using sparse convs.
Args:
@@ -256,18 +266,22 @@ class SparseEncoderSASSD(SparseEncoder):
Defaults to 'conv_module'.
"""

def __init__(self,
in_channels: int,
sparse_shape: List[int],
order: Tuple[str] = ('conv', 'norm', 'act'),
norm_cfg: dict = dict(type='BN1d', eps=1e-3, momentum=0.01),
base_channels: int = 16,
output_channels: int = 128,
encoder_channels: Tuple[tuple] = ((16, ), (32, 32, 32),
(64, 64, 64), (64, 64, 64)),
encoder_paddings: Tuple[tuple] = ((1, ), (1, 1, 1), (1, 1, 1),
((0, 1, 1), 1, 1)),
block_type: str = 'conv_module'):
def __init__(
self,
in_channels: int,
sparse_shape: List[int],
order: Tuple[str] = ('conv', 'norm', 'act'),
norm_cfg: dict = dict(type='BN1d', eps=1e-3, momentum=0.01),
base_channels: int = 16,
output_channels: int = 128,
encoder_channels: Optional[TwoTupleIntType] = ((16, ), (32, 32,
32),
(64, 64,
64), (64, 64, 64)),
encoder_paddings: Optional[TwoTupleIntType] = ((1, ), (1, 1, 1),
(1, 1, 1),
((0, 1, 1), 1, 1)),
block_type: str = 'conv_module'):
super(SparseEncoderSASSD, self).__init__(
in_channels=in_channels,
sparse_shape=sparse_shape,
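A note on the TwoTupleIntType = Tuple[Tuple[int]] alias introduced above: strictly read, it describes a one-element tuple containing a single-int tuple, while values such as ((0, 1, 1), 1, 1) mix ints and inner tuples. A sketch of a wider alias for such nested configs follows; this alias is my assumption for illustration, not part of the commit:

    from typing import Tuple, Union

    # Wider alias for nested int/tuple configs such as encoder_channels and
    # encoder_paddings (illustrative; the commit uses Tuple[Tuple[int]]).
    NestedIntTuple = Tuple[Tuple[Union[int, Tuple[int, ...]], ...], ...]

    encoder_channels: NestedIntTuple = ((16, ), (32, 32, 32), (64, 64, 64),
                                        (64, 64, 64))
    encoder_paddings: NestedIntTuple = ((1, ), (1, 1, 1), (1, 1, 1),
                                        ((0, 1, 1), 1, 1))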
60 changes: 39 additions & 21 deletions mmdet3d/models/middle_encoders/sparse_unet.py
@@ -1,5 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Optional, Tuple

import torch
from torch import Tensor, nn

from mmdet3d.models.layers.spconv import IS_SPCONV2_AVAILABLE

@@ -14,6 +17,8 @@
from mmdet3d.models.layers.sparse_block import replace_feature
from mmdet3d.registry import MODELS

TwoTupleIntType = Tuple[Tuple[int]]


@MODELS.register_module()
class SparseUNet(BaseModule):
@@ -35,21 +40,28 @@ class SparseUNet(BaseModule):
decoder_paddings (tuple[tuple[int]]): Paddings of each decode block.
"""

def __init__(self,
in_channels,
sparse_shape,
order=('conv', 'norm', 'act'),
norm_cfg=dict(type='BN1d', eps=1e-3, momentum=0.01),
base_channels=16,
output_channels=128,
encoder_channels=((16, ), (32, 32, 32), (64, 64, 64), (64, 64,
64)),
encoder_paddings=((1, ), (1, 1, 1), (1, 1, 1), ((0, 1, 1), 1,
1)),
decoder_channels=((64, 64, 64), (64, 64, 32), (32, 32, 16),
(16, 16, 16)),
decoder_paddings=((1, 0), (1, 0), (0, 0), (0, 1)),
init_cfg=None):
def __init__(
self,
in_channels: int,
sparse_shape: List[int],
order: Tuple[str] = ('conv', 'norm', 'act'),
norm_cfg: dict = dict(type='BN1d', eps=1e-3, momentum=0.01),
base_channels: int = 16,
output_channels: int = 128,
encoder_channels: Optional[TwoTupleIntType] = ((16, ), (32, 32,
32),
(64, 64,
64), (64, 64, 64)),
encoder_paddings: Optional[TwoTupleIntType] = ((1, ), (1, 1, 1),
(1, 1, 1),
((0, 1, 1), 1, 1)),
decoder_channels: Optional[TwoTupleIntType] = ((64, 64,
64), (64, 64, 32),
(32, 32,
16), (16, 16, 16)),
decoder_paddings: Optional[TwoTupleIntType] = ((1, 0), (1, 0),
(0, 0), (0, 1)),
init_cfg: bool = None):
super().__init__(init_cfg=init_cfg)
self.sparse_shape = sparse_shape
self.in_channels = in_channels
@@ -101,7 +113,8 @@ def __init__(self,
indice_key='spconv_down2',
conv_type='SparseConv3d')

def forward(self, voxel_features, coors, batch_size):
def forward(self, voxel_features: Tensor, coors: Tensor,
batch_size: int) -> Dict[str, Tensor]:
"""Forward of SparseUNet.
Args:
@@ -152,8 +165,10 @@ def forward(self, voxel_features, coors, batch_size):

return ret

def decoder_layer_forward(self, x_lateral, x_bottom, lateral_layer,
merge_layer, upsample_layer):
def decoder_layer_forward(
self, x_lateral: SparseConvTensor, x_bottom: SparseConvTensor,
lateral_layer: SparseBasicBlock, merge_layer: SparseSequential,
upsample_layer: SparseSequential) -> SparseConvTensor:
"""Forward of upsample and residual block.
Args:
@@ -176,7 +191,8 @@ def decoder_layer_forward(self, x_lateral, x_bottom, lateral_layer,
return x

@staticmethod
def reduce_channel(x, out_channels):
def reduce_channel(x: SparseConvTensor,
out_channels: int) -> SparseConvTensor:
"""reduce channel for element-wise addition.
Args:
@@ -194,7 +210,8 @@ def reduce_channel(x, out_channels):
x = replace_feature(x, features.view(n, out_channels, -1).sum(dim=2))
return x

def make_encoder_layers(self, make_block, norm_cfg, in_channels):
def make_encoder_layers(self, make_block: nn.Module, norm_cfg: dict,
in_channels: int) -> int:
"""make encoder layers using sparse convs.
Args:
Expand Down Expand Up @@ -240,7 +257,8 @@ def make_encoder_layers(self, make_block, norm_cfg, in_channels):
self.encoder_layers.add_module(stage_name, stage_layers)
return out_channels

def make_decoder_layers(self, make_block, norm_cfg, in_channels):
def make_decoder_layers(self, make_block: nn.Module, norm_cfg: dict,
in_channels: int) -> int:
"""make decoder layers using sparse convs.
Args:
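The reduce_channel helper annotated above sums groups of consecutive channels so a wide lateral feature can be added element-wise to a narrower one. A dense-tensor sketch of the same view-and-sum trick follows; the function name and the dense setting are illustrative, while the real method operates on SparseConvTensor features:

    import torch

    def reduce_channel_dense(features: torch.Tensor,
                             out_channels: int) -> torch.Tensor:
        # features: (n, in_channels); in_channels must be divisible by out_channels.
        n, in_channels = features.shape
        assert in_channels % out_channels == 0
        # Sum each group of in_channels // out_channels consecutive channels.
        return features.view(n, out_channels, -1).sum(dim=2)

    x = torch.rand(10, 64)
    print(reduce_channel_dense(x, 16).shape)  # torch.Size([10, 16])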
3 changes: 2 additions & 1 deletion mmdet3d/models/middle_encoders/voxel_set_abstraction.py
@@ -7,12 +7,13 @@
from mmcv.cnn import ConvModule
from mmcv.ops.furthest_point_sample import furthest_point_sample
from mmengine.model import BaseModule
from torch import Tensor

from mmdet3d.registry import MODELS
from mmdet3d.utils import InstanceList


def bilinear_interpolate_torch(inputs, x, y):
def bilinear_interpolate_torch(inputs: Tensor, x: Tensor, y: Tensor) -> Tensor:
"""Bilinear interpolate for inputs."""
x0 = torch.floor(x).long()
x1 = x0 + 1
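bilinear_interpolate_torch now carries Tensor annotations; for readers unfamiliar with the operation, here is a self-contained sketch of 2D bilinear sampling on an (H, W, C) feature map in the same spirit. The function name, the clamping at the borders, and the shapes are my assumptions, not code from this file:

    import torch

    def bilinear_sample(inputs: torch.Tensor, x: torch.Tensor,
                        y: torch.Tensor) -> torch.Tensor:
        # inputs: (H, W, C); x, y: (N,) continuous coordinates (x along W, y along H).
        x0 = torch.floor(x).long().clamp(0, inputs.shape[1] - 1)
        x1 = (x0 + 1).clamp(0, inputs.shape[1] - 1)
        y0 = torch.floor(y).long().clamp(0, inputs.shape[0] - 1)
        y1 = (y0 + 1).clamp(0, inputs.shape[0] - 1)
        Ia, Ib = inputs[y0, x0], inputs[y1, x0]   # features at the four corners
        Ic, Id = inputs[y0, x1], inputs[y1, x1]
        wa = (x1.float() - x) * (y1.float() - y)  # bilinear weights
        wb = (x1.float() - x) * (y - y0.float())
        wc = (x - x0.float()) * (y1.float() - y)
        wd = (x - x0.float()) * (y - y0.float())
        return (Ia * wa.unsqueeze(-1) + Ib * wb.unsqueeze(-1) +
                Ic * wc.unsqueeze(-1) + Id * wd.unsqueeze(-1))

    feat = torch.rand(8, 8, 4)
    pts_x = torch.tensor([1.5, 3.2])
    pts_y = torch.tensor([2.5, 0.7])
    print(bilinear_sample(feat, pts_x, pts_y).shape)  # torch.Size([2, 4])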