Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support MinkowskiEngine with MinkResNet #1422

Merged
merged 5 commits into from
Apr 28, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions docs/en/getting_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,14 @@ You can check the supported CUDA version for precompiled packages on the [PyTorc
`E.g. 1` If you have CUDA 10.1 installed under `/usr/local/cuda` and would like to install
PyTorch 1.5, you need to install the prebuilt PyTorch with CUDA 10.1.

```python
```shell
conda install pytorch==1.5.0 cudatoolkit=10.1 torchvision==0.6.0 -c pytorch
```

`E.g. 2` If you have CUDA 9.2 installed under `/usr/local/cuda` and would like to install
PyTorch 1.3.1, you need to install the prebuilt PyTorch with CUDA 9.2.

```python
```shell
conda install pytorch=1.3.1 cudatoolkit=9.2 torchvision=0.4.2 -c pytorch
```

Expand Down Expand Up @@ -192,6 +192,14 @@ you can install it before installing MMCV.

4. Some dependencies are optional. Simply running `pip install -v -e .` will only install the minimum runtime requirements. To use optional dependencies like `albumentations` and `imagecorruptions` either install them manually with `pip install -r requirements/optional.txt` or specify desired extras when calling `pip` (e.g. `pip install -v -e .[optional]`). Valid keys for the extras field are: `all`, `tests`, `build`, and `optional`.

We also support Minkowski Engine as a sparse convolution backend. If necessary, please follow the original [installation guide](https://github.com/NVIDIA/MinkowskiEngine#installation) or install it with `pip`:

```shell
apt-get install -y python3-dev libopenblas-dev
pip install ninja
pip install -U git+https://github.com/NVIDIA/MinkowskiEngine --install-option="--blas=openblas" -v --no-deps
```

5. The code cannot be built for CPU-only environments (where CUDA isn't available) for now.

## Another option: Docker Image
Expand Down
3 changes: 2 additions & 1 deletion mmdet3d/models/backbones/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from mmdet.models.backbones import SSDVGG, HRNet, ResNet, ResNetV1d, ResNeXt
from .dgcnn import DGCNNBackbone
from .dla import DLANet
from .mink_resnet import MinkResNet
from .multi_backbone import MultiBackbone
from .nostem_regnet import NoStemRegNet
from .pointnet2_sa_msg import PointNet2SAMSG
Expand All @@ -11,5 +12,5 @@
__all__ = [
'ResNet', 'ResNetV1d', 'ResNeXt', 'SSDVGG', 'HRNet', 'NoStemRegNet',
'SECOND', 'DGCNNBackbone', 'PointNet2SASSG', 'PointNet2SAMSG',
'MultiBackbone', 'DLANet'
'MultiBackbone', 'DLANet', 'MinkResNet'
]
116 changes: 116 additions & 0 deletions mmdet3d/models/backbones/mink_resnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# Copyright (c) OpenMMLab. All rights reserved.
# Follow https://github.com/NVIDIA/MinkowskiEngine/blob/master/examples/resnet.py # noqa
# and mmcv.cnn.ResNet
# MinkowskiEngine is an optional dependency: importing this module must not
# fail when it is absent, so the import is guarded and only a warning is shown.
try:
    import MinkowskiEngine as ME
    from MinkowskiEngine.modules.resnet_block import BasicBlock, Bottleneck
except ImportError:
    import warnings
    warnings.warn(
        'Please follow `getting_started.md` to install MinkowskiEngine.')
    # The residual blocks are referenced in the class body of MinkResNet
    # (its static `arch_settings` table), so the names must exist even when
    # MinkowskiEngine is not installed.
    BasicBlock, Bottleneck = None, None

import torch.nn as nn

from mmdet3d.models.builder import BACKBONES


@BACKBONES.register_module()
class MinkResNet(nn.Module):
    r"""Minkowski ResNet backbone. See `4D Spatio-Temporal ConvNets
    <https://arxiv.org/abs/1904.08755>`_ for more details.

    Args:
        depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
        in_channels (int): Number of input channels, 3 for RGB.
        num_stages (int, optional): Resnet stages. Default: 4.
        pool (bool, optional): Add max pooling after first conv if True.
            Default: True.
    """
    # depth -> (residual block class, number of blocks in each stage)
    arch_settings = {
        18: (BasicBlock, (2, 2, 2, 2)),
        34: (BasicBlock, (3, 4, 6, 3)),
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self, depth, in_channels, num_stages=4, pool=True):
        super(MinkResNet, self).__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for resnet')
        assert 4 >= num_stages >= 1
        block, stage_blocks = self.arch_settings[depth]
        if block is None:
            # The block classes are replaced by None when the optional
            # MinkowskiEngine import fails; report that explicitly instead
            # of failing later on the undefined `ME` name.
            raise ImportError(
                'Please follow `getting_started.md` to install '
                'MinkowskiEngine.')
        stage_blocks = stage_blocks[:num_stages]
        self.num_stages = num_stages
        self.pool = pool

        # Stem: stride-2 convolution -> instance norm -> ReLU (-> max pool).
        self.inplanes = 64
        self.conv1 = ME.MinkowskiConvolution(
            in_channels, self.inplanes, kernel_size=3, stride=2, dimension=3)
        # Maybe BatchNorm would be better, but we follow the original
        # implementation.
        self.norm1 = ME.MinkowskiInstanceNorm(self.inplanes)
        self.relu = ME.MinkowskiReLU(inplace=True)
        if self.pool:
            self.maxpool = ME.MinkowskiMaxPooling(
                kernel_size=2, stride=2, dimension=3)

        # Stage i is stored as `layer{i}` with 64 * 2**i base planes; every
        # stage is built with stride 2.
        for i, num_blocks in enumerate(stage_blocks):
            setattr(
                self, f'layer{i}',
                self._make_layer(block, 64 * 2**i, num_blocks, stride=2))

    def init_weights(self):
        """Initialize convolution kernels and batch norm parameters.

        Kernels of ``ME.MinkowskiConvolution`` modules get Kaiming normal
        initialization; ``ME.MinkowskiBatchNorm`` modules get weight 1 and
        bias 0. Other modules are left untouched.
        """
        for m in self.modules():
            if isinstance(m, ME.MinkowskiConvolution):
                ME.utils.kaiming_normal_(
                    m.kernel, mode='fan_out', nonlinearity='relu')

            if isinstance(m, ME.MinkowskiBatchNorm):
                nn.init.constant_(m.bn.weight, 1)
                nn.init.constant_(m.bn.bias, 0)

    def _make_layer(self, block, planes, blocks, stride):
        """Build one residual stage and update ``self.inplanes``.

        Args:
            block (type): Residual block class; must provide ``expansion``.
            planes (int): Base number of channels of the stage.
            blocks (int): Number of residual blocks in the stage.
            stride (int): Stride of the first block of the stage.

        Returns:
            nn.Sequential: The residual blocks of the stage.
        """
        out_channels = planes * block.expansion
        downsample = None
        if stride != 1 or self.inplanes != out_channels:
            # Projection shortcut (kernel size 1) so that the identity
            # branch matches the stride and channels of the main branch.
            downsample = nn.Sequential(
                ME.MinkowskiConvolution(
                    self.inplanes,
                    out_channels,
                    kernel_size=1,
                    stride=stride,
                    dimension=3),
                ME.MinkowskiBatchNorm(out_channels))
        # Only the first block applies the stride and the shortcut.
        layers = [
            block(
                self.inplanes,
                planes,
                stride=stride,
                downsample=downsample,
                dimension=3)
        ]
        self.inplanes = out_channels
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, stride=1, dimension=3))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward pass of ResNet.

        Args:
            x (ME.SparseTensor): Input sparse tensor.

        Returns:
            list[ME.SparseTensor]: Output sparse tensors, one per stage.
        """
        x = self.conv1(x)
        x = self.norm1(x)
        x = self.relu(x)
        if self.pool:
            x = self.maxpool(x)
        outs = []
        for i in range(self.num_stages):
            x = getattr(self, f'layer{i}')(x)
            outs.append(x)
        return outs
52 changes: 52 additions & 0 deletions tests/test_models/test_backbones.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,3 +353,55 @@ def test_dla_net():
assert results[3].shape == torch.Size([4, 128, 4, 4])
assert results[4].shape == torch.Size([4, 256, 2, 2])
assert results[5].shape == torch.Size([4, 512, 1, 1])


def test_mink_resnet():
    """Check MinkResNet output sizes and tensor strides on random clouds."""
    if not torch.cuda.is_available():
        pytest.skip('test requires GPU and torch+cuda')

    try:
        import MinkowskiEngine as ME
    except ImportError:
        pytest.skip('test requires MinkowskiEngine installation')

    np.random.seed(42)
    coordinates, features = [], []
    # Batch of 2 point clouds. For each cloud the coordinates are drawn
    # before the features; keep this order, the seed is fixed and the
    # expected sizes below are exact.
    for _ in range(2):
        cloud_xyz = np.random.rand(500, 3) * 100
        coordinates.append(torch.from_numpy(cloud_xyz).float().cuda())
        cloud_feats = np.random.rand(500, 3)
        features.append(torch.from_numpy(cloud_feats).float().cuda())
    tensor_coordinates, tensor_features = ME.utils.sparse_collate(
        coordinates, features)
    sparse_input = ME.SparseTensor(
        features=tensor_features, coordinates=tensor_coordinates)

    # MinkResNet34 with 4 outputs
    model = build_backbone(
        dict(type='MinkResNet', depth=34, in_channels=3)).cuda()
    model.init_weights()

    outputs = model(sparse_input)
    assert len(outputs) == 4
    # (number of points, channels, tensor stride) expected for each stage
    expected = [(900, 64, 8), (472, 128, 16), (105, 256, 32), (16, 512, 64)]
    for out, (num_points, channels, stride) in zip(outputs, expected):
        assert out.F.shape == torch.Size([num_points, channels])
        assert out.tensor_stride[0] == stride

    # MinkResNet34 with 2 outputs and without max pooling
    model = build_backbone(
        dict(
            type='MinkResNet',
            depth=34,
            in_channels=3,
            num_stages=2,
            pool=False)).cuda()
    model.init_weights()

    outputs = model(sparse_input)
    assert len(outputs) == 2
    expected = [(985, 64, 4), (900, 128, 8)]
    for out, (num_points, channels, stride) in zip(outputs, expected):
        assert out.F.shape == torch.Size([num_points, channels])
        assert out.tensor_stride[0] == stride