Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] FCOS3D BBox Coder #940

Merged
merged 5 commits into from
Sep 15, 2021
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions configs/_base_/models/fcos3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
loss_centerness=dict(
type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
bbox_coder=dict(type='FCOS3DBBoxCoder', code_size=9),
norm_on_bbox=True,
centerness_on_reg=True,
center_sampling=True,
Expand Down
3 changes: 2 additions & 1 deletion mmdet3d/core/bbox/coders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from .anchor_free_bbox_coder import AnchorFreeBBoxCoder
from .centerpoint_bbox_coders import CenterPointBBoxCoder
from .delta_xyzwhlr_bbox_coder import DeltaXYZWLHRBBoxCoder
from .fcos3d_bbox_coder import FCOS3DBBoxCoder
from .groupfree3d_bbox_coder import GroupFree3DBBoxCoder
from .partial_bin_based_bbox_coder import PartialBinBasedBBoxCoder
from .point_xyzwhlr_bbox_coder import PointXYZWHLRBBoxCoder

__all__ = [
'build_bbox_coder', 'DeltaXYZWLHRBBoxCoder', 'PartialBinBasedBBoxCoder',
'CenterPointBBoxCoder', 'AnchorFreeBBoxCoder', 'GroupFree3DBBoxCoder',
'PointXYZWHLRBBoxCoder'
'PointXYZWHLRBBoxCoder', 'FCOS3DBBoxCoder'
]
127 changes: 127 additions & 0 deletions mmdet3d/core/bbox/coders/fcos3d_bbox_coder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
import numpy as np
import torch

from mmdet.core.bbox import BaseBBoxCoder
from mmdet.core.bbox.builder import BBOX_CODERS
from ..structures import limit_period


@BBOX_CODERS.register_module()
class FCOS3DBBoxCoder(BaseBBoxCoder):
"""Bounding box coder for FCOS3D.

Args:
base_depths (tuple[tuple[float]]): Depth references for decode box
depth. Defaults to None.
base_dims (tuple[tuple[float]]): Dimension references for decode box
dimension. Defaults to None.
code_size (int): The dimension of boxes to be encoded. Defaults to 7.
norm_on_bbox (bool): Whether to apply normalization on the bounding
box 2D attributes. Defaults to True.
"""

def __init__(self,
base_depths=None,
Tai-Wang marked this conversation as resolved.
Show resolved Hide resolved
base_dims=None,
code_size=7,
norm_on_bbox=True):
super(FCOS3DBBoxCoder, self).__init__()
self.base_depths = base_depths
self.base_dims = base_dims
self.bbox_code_size = code_size
self.norm_on_bbox = norm_on_bbox

def encode(self, gt_bboxes_3d, gt_labels_3d, gt_bboxes, gt_labels):
# TODO: refactor the encoder in the FCOS3D and PGD head
pass

def decode(self, bbox_out, scale, stride, training, cls_score=None):
"""Decode regressed results into 3D predictions.

Note that offsets are not transformed to the projected 3D centers.

Args:
bbox_out (torch.Tensor): Raw bounding box predictions in shape
Tai-Wang marked this conversation as resolved.
Show resolved Hide resolved
[N, C, H, W].
scale (tuple[`Scale`]): Learnable scale parameters.
stride (tuple[int]): Stride for a specific feature level.
training (bool): Whether the decoding is in the training
procedure.
cls_score (torch.Tensor): Classification score map for deciding
which base depth or dim is used. Defaults to None.

Returns:
torch.Tensor: Decoded boxes.
"""
# scale the bbox_out of different level
# only apply to offset, depth and size prediction
scale_offset, scale_depth, scale_size = scale[0:3]
Tai-Wang marked this conversation as resolved.
Show resolved Hide resolved

clone_bbox_out = bbox_out.clone()
bbox_out[:, :2] = scale_offset(clone_bbox_out[:, :2]).float()
bbox_out[:, 2] = scale_depth(clone_bbox_out[:, 2]).float()
bbox_out[:, 3:6] = scale_size(clone_bbox_out[:, 3:6]).float()

if self.base_depths is None:
bbox_out[:, 2] = bbox_out[:, 2].exp()
elif len(self.base_depths) == 1: # only single prior
mean = self.base_depths[0][0]
std = self.base_depths[0][1]
bbox_out[:, 2] = mean + bbox_out.clone()[:, 2] * std
else: # multi-class priors
assert len(self.base_depths) == cls_score.shape[1], \
'The number of multi-class depth priors should be equal to ' \
'the number of categories.'
indices = cls_score.max(dim=1)[1]
depth_priors = cls_score.new_tensor(
self.base_depths)[indices, :].permute(0, 3, 1, 2)
mean = depth_priors[:, 0]
std = depth_priors[:, 1]
bbox_out[:, 2] = mean + bbox_out.clone()[:, 2] * std

bbox_out[:, 3:6] = bbox_out[:, 3:6].exp()
if self.base_dims is not None:
assert len(self.base_dims) == cls_score.shape[1], \
'The number of anchor sizes should be equal to the number ' \
'of categories.'
indices = cls_score.max(dim=1)[1]
size_priors = cls_score.new_tensor(
self.base_dims)[indices, :].permute(0, 3, 1, 2)
bbox_out[:, 3:6] = size_priors * bbox_out.clone()[:, 3:6]

assert self.norm_on_bbox is True, 'Setting norm_on_bbox to False '\
'has not been thoroughly tested for FCOS3D.'
if self.norm_on_bbox:
if not training:
Tai-Wang marked this conversation as resolved.
Show resolved Hide resolved
# Note that this line is conducted only when testing
bbox_out[:, :2] *= stride

return bbox_out

@staticmethod
def decode_yaw(bbox_out, centers2d, dir_cls, dir_offset, cam2img):
"""Decode yaw angle and change it from local to global.i.

Args:
bbox_out (torch.Tensor): Bounding box predictions in shape
Tai-Wang marked this conversation as resolved.
Show resolved Hide resolved
[N, C] with yaws to be decoded.
centers2d (torch.Tensor): Projected 3D-center on the image planes
corresponding to the box predictions.
dir_cls (torch.Tensor): Predicted direction classes.
dir_offset (float): Direction offset before dividing all the
directions into several classes.
cam2img (torch.Tensor): Camera intrinsic matrix in shape [4, 4].

Returns:
torch.Tensor: Bounding boxes with decoded yaws.
"""
if bbox_out.shape[0] > 0:
dir_rot = limit_period(bbox_out[..., 6] - dir_offset, 0, np.pi)
bbox_out[
...,
6] = dir_rot + dir_offset + np.pi * dir_cls.to(bbox_out.dtype)

bbox_out[:, 6] = torch.atan2(centers2d[:, 0] - cam2img[0, 2],
cam2img[0, 0]) + bbox_out[:, 6]

return bbox_out
52 changes: 19 additions & 33 deletions mmdet3d/models/dense_heads/fcos_mono3d_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from mmdet3d.core import box3d_multiclass_nms, limit_period, xywhr2xyxyr
from mmdet.core import multi_apply
from mmdet.core.bbox.builder import build_bbox_coder
from mmdet.models.builder import HEADS, build_loss
from .anchor_free_mono3d_head import AnchorFreeMono3DHead

Expand Down Expand Up @@ -73,6 +74,7 @@ def __init__(self,
type='CrossEntropyLoss',
use_sigmoid=True,
loss_weight=1.0),
bbox_coder=dict(type='FCOS3DBBoxCoder', code_size=9),
norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
centerness_branch=(64, ),
init_cfg=None,
Expand All @@ -95,6 +97,8 @@ def __init__(self,
init_cfg=init_cfg,
**kwargs)
self.loss_centerness = build_loss(loss_centerness)
bbox_coder['code_size'] = self.bbox_code_size
self.bbox_coder = build_bbox_coder(bbox_coder)
if init_cfg is None:
self.init_cfg = dict(
type='Normal',
Expand All @@ -110,9 +114,11 @@ def _init_layers(self):
conv_channels=self.centerness_branch,
conv_strides=(1, ) * len(self.centerness_branch))
self.conv_centerness = nn.Conv2d(self.centerness_branch[-1], 1, 1)
self.scale_dim = 3 # only for offset, depth and size regression
self.scales = nn.ModuleList([
nn.ModuleList([Scale(1.0) for _ in range(3)]) for _ in self.strides
]) # only for offset, depth and size regression
nn.ModuleList([Scale(1.0) for _ in range(self.scale_dim)])
for _ in self.strides
])

def forward(self, feats):
"""Forward features from the upstream network.
Expand All @@ -139,7 +145,7 @@ def forward(self, feats):
each is a 4D-tensor, the channel number is num_points * 1.
"""
return multi_apply(self.forward_single, feats, self.scales,
self.strides)
self.strides)[:5]
Tai-Wang marked this conversation as resolved.
Show resolved Hide resolved

def forward_single(self, x, scale, stride):
"""Forward features of a single scale levle.
Expand Down Expand Up @@ -169,26 +175,12 @@ def forward_single(self, x, scale, stride):
for conv_centerness_prev_layer in self.conv_centerness_prev:
clone_cls_feat = conv_centerness_prev_layer(clone_cls_feat)
centerness = self.conv_centerness(clone_cls_feat)
# scale the bbox_pred of different level
# only apply to offset, depth and size prediction
scale_offset, scale_depth, scale_size = scale[0:3]

clone_bbox_pred = bbox_pred.clone()
bbox_pred[:, :2] = scale_offset(clone_bbox_pred[:, :2]).float()
bbox_pred[:, 2] = scale_depth(clone_bbox_pred[:, 2]).float()
bbox_pred[:, 3:6] = scale_size(clone_bbox_pred[:, 3:6]).float()
bbox_pred = self.bbox_coder.decode(bbox_pred, scale, stride,
self.training, cls_score)

bbox_pred[:, 2] = bbox_pred[:, 2].exp()
bbox_pred[:, 3:6] = bbox_pred[:, 3:6].exp() + 1e-6 # avoid size=0

assert self.norm_on_bbox is True, 'Setting norm_on_bbox to False '\
'has not been thoroughly tested for FCOS3D.'
if self.norm_on_bbox:
if not self.training:
# Note that this line is conducted only when testing
bbox_pred[:, :2] *= stride

return cls_score, bbox_pred, dir_cls_pred, attr_pred, centerness
return cls_score, bbox_pred, dir_cls_pred, attr_pred, centerness, \
cls_feat, reg_feat

@staticmethod
def add_sin_difference(boxes1, boxes2):
Expand Down Expand Up @@ -652,19 +644,13 @@ def _get_bboxes_single(self,
mlvl_dir_scores = torch.cat(mlvl_dir_scores)

# change local yaw to global yaw for 3D nms
if mlvl_bboxes.shape[0] > 0:
dir_rot = limit_period(mlvl_bboxes[..., 6] - self.dir_offset, 0,
np.pi)
mlvl_bboxes[..., 6] = (
dir_rot + self.dir_offset +
np.pi * mlvl_dir_scores.to(mlvl_bboxes.dtype))

cam_intrinsic = mlvl_centers2d.new_zeros((4, 4))
cam_intrinsic[:view.shape[0], :view.shape[1]] = \
cam2img = mlvl_centers2d.new_zeros((4, 4))
cam2img[:view.shape[0], :view.shape[1]] = \
mlvl_centers2d.new_tensor(view)
mlvl_bboxes[:, 6] = torch.atan2(
mlvl_centers2d[:, 0] - cam_intrinsic[0, 2],
cam_intrinsic[0, 0]) + mlvl_bboxes[:, 6]
mlvl_bboxes = self.bbox_coder.decode_yaw(mlvl_bboxes, mlvl_centers2d,
mlvl_dir_scores,
self.dir_offset, cam2img)

mlvl_bboxes_for_nms = xywhr2xyxyr(input_meta['box_type_3d'](
mlvl_bboxes, box_dim=self.bbox_code_size,
origin=(0.5, 0.5, 0.5)).bev)
Expand Down
83 changes: 83 additions & 0 deletions tests/test_utils/test_bbox_coders.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import Scale
from torch import nn as nn

from mmdet3d.core.bbox import DepthInstance3DBoxes, LiDARInstance3DBoxes
from mmdet.core import build_bbox_coder
Expand Down Expand Up @@ -382,3 +384,84 @@ def test_point_xyzwhlr_bbox_coder():
# test decode
bbox3d_out = boxcoder.decode(bbox_target, points, gt_labels_3d)
assert torch.allclose(bbox3d_out, gt_bboxes_3d, atol=1e-4)


def test_fcos3d_bbox_coder():
# test a config without priors
bbox_coder_cfg = dict(
type='FCOS3DBBoxCoder',
base_depths=None,
base_dims=None,
code_size=7,
norm_on_bbox=True)
bbox_coder = build_bbox_coder(bbox_coder_cfg)

# test decode
# [2, 7, 1, 1]
batch_bbox_out = torch.tensor([[[[0.3130]], [[0.7094]], [[0.8743]],
[[0.0570]], [[0.5579]], [[0.1593]],
[[0.4553]]],
[[[0.7758]], [[0.2298]], [[0.3925]],
[[0.6307]], [[0.4377]], [[0.3339]],
[[0.1966]]]])
batch_scale = nn.ModuleList([Scale(1.0) for _ in range(3)])
stride = 2
training = False
cls_score = torch.randn([2, 2, 1, 1]).sigmoid()
decode_bbox_out = bbox_coder.decode(batch_bbox_out, batch_scale, stride,
training, cls_score)

expected_bbox_out = torch.tensor([[[[0.6261]], [[1.4188]], [[2.3971]],
[[1.0586]], [[1.7470]], [[1.1727]],
[[0.4553]]],
[[[1.5516]], [[0.4596]], [[1.4806]],
[[1.8790]], [[1.5492]], [[1.3965]],
[[0.1966]]]])
assert torch.allclose(decode_bbox_out, expected_bbox_out, atol=1e-3)

# test a config with priors
prior_bbox_coder_cfg = dict(
type='FCOS3DBBoxCoder',
base_depths=((28, 13), (25, 12)),
base_dims=((2, 3, 1), (1, 2, 3)),
code_size=7,
norm_on_bbox=True)
prior_bbox_coder = build_bbox_coder(prior_bbox_coder_cfg)

# test decode
batch_bbox_out = torch.tensor([[[[0.3130]], [[0.7094]], [[0.8743]],
[[0.0570]], [[0.5579]], [[0.1593]],
[[0.4553]]],
[[[0.7758]], [[0.2298]], [[0.3925]],
[[0.6307]], [[0.4377]], [[0.3339]],
[[0.1966]]]])
batch_scale = nn.ModuleList([Scale(1.0) for _ in range(3)])
stride = 2
training = False
cls_score = torch.tensor([[[[0.5811]], [[0.6198]]], [[[0.4889]],
[[0.8142]]]])
decode_bbox_out = prior_bbox_coder.decode(batch_bbox_out, batch_scale,
stride, training, cls_score)
expected_bbox_out = torch.tensor([[[[0.6260]], [[1.4188]], [[35.4916]],
[[1.0587]], [[3.4940]], [[3.5181]],
[[0.4553]]],
[[[1.5516]], [[0.4596]], [[29.7100]],
[[1.8789]], [[3.0983]], [[4.1892]],
[[0.1966]]]])
assert torch.allclose(decode_bbox_out, expected_bbox_out, atol=1e-3)

# test decode_yaw
decode_bbox_out = decode_bbox_out.permute(0, 2, 3, 1).view(-1, 7)
batch_centers2d = torch.tensor([[100, 150], [200, 100]])
batch_dir_cls = torch.tensor([0, 1])
dir_offset = 0.7854
cam2img = torch.tensor([[700, 0, 450, 0], [0, 700, 200, 0], [0, 0, 1, 0],
[0, 0, 0, 1]])
decode_bbox_out = prior_bbox_coder.decode_yaw(decode_bbox_out,
batch_centers2d,
batch_dir_cls, dir_offset,
cam2img)
expected_bbox_out = torch.tensor(
[[0.6260, 1.4188, 35.4916, 1.0587, 3.4940, 3.5181, 3.1332],
[1.5516, 0.4596, 29.7100, 1.8789, 3.0983, 4.1892, 6.1368]])
assert torch.allclose(decode_bbox_out, expected_bbox_out, atol=1e-3)