diff --git a/docs/en/getting_started.md b/docs/en/getting_started.md index 0c3125e48b..bb02d885b9 100644 --- a/docs/en/getting_started.md +++ b/docs/en/getting_started.md @@ -76,14 +76,14 @@ You can check the supported CUDA version for precompiled packages on the [PyTorc `E.g. 1` If you have CUDA 10.1 installed under `/usr/local/cuda` and would like to install PyTorch 1.5, you need to install the prebuilt PyTorch with CUDA 10.1. -```python +```shell conda install pytorch==1.5.0 cudatoolkit=10.1 torchvision==0.6.0 -c pytorch ``` `E.g. 2` If you have CUDA 9.2 installed under `/usr/local/cuda` and would like to install PyTorch 1.3.1., you need to install the prebuilt PyTorch with CUDA 9.2. -```python +```shell conda install pytorch=1.3.1 cudatoolkit=9.2 torchvision=0.4.2 -c pytorch ``` @@ -192,6 +192,13 @@ you can install it before installing MMCV. 4. Some dependencies are optional. Simply running `pip install -v -e .` will only install the minimum runtime requirements. To use optional dependencies like `albumentations` and `imagecorruptions` either install them manually with `pip install -r requirements/optional.txt` or specify desired extras when calling `pip` (e.g. `pip install -v -e .[optional]`). Valid keys for the extras field are: `all`, `tests`, `build`, and `optional`. + We also support Minkowski Engine as a sparse convolution backend. If necessary please follow original [installation guide](https://github.com/NVIDIA/MinkowskiEngine#installation) or use `pip`: + + ```shell + conda install openblas-devel -c anaconda + pip install -U git+https://github.com/NVIDIA/MinkowskiEngine -v --no-deps --install-option="--blas_include_dirs=/opt/conda/include" --install-option="--blas=openblas" + ``` + 5. The code can not be built for CPU only environment (where CUDA isn't available) for now. ## Another option: Docker Image diff --git a/mmdet3d/models/backbones/__init__.py b/mmdet3d/models/backbones/__init__.py index 9403bd72c5..d51c16d2f6 100644 --- a/mmdet3d/models/backbones/__init__.py +++ b/mmdet3d/models/backbones/__init__.py @@ -2,6 +2,7 @@ from mmdet.models.backbones import SSDVGG, HRNet, ResNet, ResNetV1d, ResNeXt from .dgcnn import DGCNNBackbone from .dla import DLANet +from .mink_resnet import MinkResNet from .multi_backbone import MultiBackbone from .nostem_regnet import NoStemRegNet from .pointnet2_sa_msg import PointNet2SAMSG @@ -11,5 +12,5 @@ __all__ = [ 'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet', 'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG', - 'MultiBackbone', 'DLANet' + 'MultiBackbone', 'DLANet', 'MinkResNet' ] diff --git a/mmdet3d/models/backbones/mink_resnet.py b/mmdet3d/models/backbones/mink_resnet.py new file mode 100644 index 0000000000..35a79ce233 --- /dev/null +++ b/mmdet3d/models/backbones/mink_resnet.py @@ -0,0 +1,116 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Follow https://github.com/NVIDIA/MinkowskiEngine/blob/master/examples/resnet.py # noqa +# and mmcv.cnn.ResNet +try: + import MinkowskiEngine as ME + from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck +except ImportError: + import warnings + warnings.warn( + 'Please follow `getting_started.md` to install MinkowskiEngine.`') + # blocks are used in the static part of MinkResNet + BasicBlock, Bottleneck = None, None + +import torch.nn as nn + +from mmdet3d.models.builder import BACKBONES + + +@BACKBONES.register_module() +class MinkResNet(nn.Module): + r"""Minkowski ResNet backbone. See `4D Spatio-Temporal ConvNets + `_ for more details. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + in_channels (ont): Number of input channels, 3 for RGB. + num_stages (int, optional): Resnet stages. Default: 4. + pool (bool, optional): Add max pooling after first conv if True. + Default: True. + """ + arch_settings = { + 18: (BasicBlock, (2, 2, 2, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, depth, in_channels, num_stages=4, pool=True): + super(MinkResNet, self).__init__() + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + assert 4 >= num_stages >= 1 + block, stage_blocks = self.arch_settings[depth] + stage_blocks = stage_blocks[:num_stages] + self.num_stages = num_stages + self.pool = pool + + self.inplanes = 64 + self.conv1 = ME.MinkowskiConvolution( + in_channels, self.inplanes, kernel_size=3, stride=2, dimension=3) + # May be BatchNorm is better, but we follow original implementation. + self.norm1 = ME.MinkowskiInstanceNorm(self.inplanes) + self.relu = ME.MinkowskiReLU(inplace=True) + if self.pool: + self.maxpool = ME.MinkowskiMaxPooling( + kernel_size=2, stride=2, dimension=3) + + for i, num_blocks in enumerate(stage_blocks): + setattr( + self, f'layer{i}', + self._make_layer(block, 64 * 2**i, stage_blocks[i], stride=2)) + + def init_weights(self): + for m in self.modules(): + if isinstance(m, ME.MinkowskiConvolution): + ME.utils.kaiming_normal_( + m.kernel, mode='fan_out', nonlinearity='relu') + + if isinstance(m, ME.MinkowskiBatchNorm): + nn.init.constant_(m.bn.weight, 1) + nn.init.constant_(m.bn.bias, 0) + + def _make_layer(self, block, planes, blocks, stride): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.Sequential( + ME.MinkowskiConvolution( + self.inplanes, + planes * block.expansion, + kernel_size=1, + stride=stride, + dimension=3), + ME.MinkowskiBatchNorm(planes * block.expansion)) + layers = [] + layers.append( + block( + self.inplanes, + planes, + stride=stride, + downsample=downsample, + dimension=3)) + self.inplanes = planes * block.expansion + for i in range(1, blocks): + layers.append(block(self.inplanes, planes, stride=1, dimension=3)) + return nn.Sequential(*layers) + + def forward(self, x): + """Forward pass of ResNet. + + Args: + x (ME.SparseTensor): Input sparse tensor. + + Returns: + list[ME.SparseTensor]: Output sparse tensors. + """ + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + if self.pool: + x = self.maxpool(x) + outs = [] + for i in range(self.num_stages): + x = getattr(self, f'layer{i}')(x) + outs.append(x) + return outs diff --git a/tests/test_models/test_backbones.py b/tests/test_models/test_backbones.py index 392e0ec4c0..c7550448cd 100644 --- a/tests/test_models/test_backbones.py +++ b/tests/test_models/test_backbones.py @@ -353,3 +353,55 @@ def test_dla_net(): assert results[3].shape == torch.Size([4, 128, 4, 4]) assert results[4].shape == torch.Size([4, 256, 2, 2]) assert results[5].shape == torch.Size([4, 512, 1, 1]) + + +def test_mink_resnet(): + if not torch.cuda.is_available(): + pytest.skip('test requires GPU and torch+cuda') + + try: + import MinkowskiEngine as ME + except ImportError: + pytest.skip('test requires MinkowskiEngine installation') + + coordinates, features = [], [] + np.random.seed(42) + # batch of 2 point clouds + for i in range(2): + c = torch.from_numpy(np.random.rand(500, 3) * 100) + coordinates.append(c.float().cuda()) + f = torch.from_numpy(np.random.rand(500, 3)) + features.append(f.float().cuda()) + tensor_coordinates, tensor_features = ME.utils.sparse_collate( + coordinates, features) + x = ME.SparseTensor( + features=tensor_features, coordinates=tensor_coordinates) + + # MinkResNet34 with 4 outputs + cfg = dict(type='MinkResNet', depth=34, in_channels=3) + self = build_backbone(cfg).cuda() + self.init_weights() + + y = self(x) + assert len(y) == 4 + assert y[0].F.shape == torch.Size([900, 64]) + assert y[0].tensor_stride[0] == 8 + assert y[1].F.shape == torch.Size([472, 128]) + assert y[1].tensor_stride[0] == 16 + assert y[2].F.shape == torch.Size([105, 256]) + assert y[2].tensor_stride[0] == 32 + assert y[3].F.shape == torch.Size([16, 512]) + assert y[3].tensor_stride[0] == 64 + + # MinkResNet50 with 2 outputs + cfg = dict( + type='MinkResNet', depth=34, in_channels=3, num_stages=2, pool=False) + self = build_backbone(cfg).cuda() + self.init_weights() + + y = self(x) + assert len(y) == 2 + assert y[0].F.shape == torch.Size([985, 64]) + assert y[0].tensor_stride[0] == 4 + assert y[1].F.shape == torch.Size([900, 128]) + assert y[1].tensor_stride[0] == 8