Skip to content

Commit

Permalink
[Feature] Add MonoFlex Head (#1044)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZCMax authored Jan 21, 2022
1 parent 4590418 commit 8538177
Show file tree
Hide file tree
Showing 12 changed files with 1,093 additions and 26 deletions.
26 changes: 13 additions & 13 deletions mmdet3d/core/bbox/coders/monoflex_bbox_coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,24 +81,24 @@ def encode(self, gt_bboxes_3d):
torch.Tensor: Targets of orientations.
"""
local_yaw = gt_bboxes_3d.local_yaw

# encode local yaw (-pi ~ pi) to multibin format
encode_local_yaw = np.zeros(self.num_dir_bins * 2)
encode_local_yaw = local_yaw.new_zeros(
[local_yaw.shape[0], self.num_dir_bins * 2])
bin_size = 2 * np.pi / self.num_dir_bins
margin_size = bin_size * self.bin_margin

bin_centers = self.bin_centers
bin_centers = local_yaw.new_tensor(self.bin_centers)
range_size = bin_size / 2 + margin_size

offsets = local_yaw - bin_centers.unsqueeze(0)
offsets = local_yaw.unsqueeze(1) - bin_centers.unsqueeze(0)
offsets[offsets > np.pi] = offsets[offsets > np.pi] - 2 * np.pi
offsets[offsets < -np.pi] = offsets[offsets < -np.pi] + 2 * np.pi

for i in range(self.num_dir_bins):
offset = offsets[:, i]
inds = abs(offset) < range_size
encode_local_yaw[inds, i] = 1
encode_local_yaw[inds, i + self.num_dir_bins] = offset
encode_local_yaw[inds, i + self.num_dir_bins] = offset[inds]

orientation_target = encode_local_yaw

Expand Down Expand Up @@ -164,7 +164,7 @@ def decode(self, bbox, base_centers2d, labels, downsample_ratio, cam2imgs):
pred_direct_depth_uncertainty = bbox[:, 49:50].squeeze(-1)

# 2 dimension of offsets x keypoints (8 corners + top/bottom center)
pred_keypoints2d = bbox[:, 6:26]
pred_keypoints2d = bbox[:, 6:26].reshape(-1, 10, 2)

# 1 dimension for depth offsets
pred_direct_depth_offsets = bbox[:, 48:49].squeeze(-1)
Expand Down Expand Up @@ -273,11 +273,11 @@ def decode_location(self,
raise NotImplementedError
# (N, 3)
centers2d_img = \
torch.cat(centers2d_img, depths.unsqueeze(-1), dim=1)
torch.cat((centers2d_img, depths.unsqueeze(-1)), dim=1)
# (N, 4, 1)
centers2d_extend = \
torch.cat((centers2d_img, centers2d_img.new_ones(N, 1)),
dim=1).unqueeze(-1)
dim=1).unsqueeze(-1)
locations = torch.matmul(cam2imgs_inv, centers2d_extend).squeeze(-1)

return locations[:, :3]
Expand Down Expand Up @@ -450,15 +450,15 @@ def decode_orientation(self, ori_vector, locations):
local_yaws = orientations
yaws = local_yaws + rays

larger_idx = (yaws > np.pi).nonzero()
small_idx = (yaws < -np.pi).nonzero()
larger_idx = (yaws > np.pi).nonzero(as_tuple=False)
small_idx = (yaws < -np.pi).nonzero(as_tuple=False)
if len(larger_idx) != 0:
yaws[larger_idx] -= 2 * np.pi
if len(small_idx) != 0:
yaws[small_idx] += 2 * np.pi

larger_idx = (local_yaws > np.pi).nonzero()
small_idx = (local_yaws < -np.pi).nonzero()
larger_idx = (local_yaws > np.pi).nonzero(as_tuple=False)
small_idx = (local_yaws < -np.pi).nonzero(as_tuple=False)
if len(larger_idx) != 0:
local_yaws[larger_idx] -= 2 * np.pi
if len(small_idx) != 0:
Expand Down Expand Up @@ -491,7 +491,7 @@ def decode_bboxes2d(self, reg_bboxes2d, base_centers2d):

return bboxes2d

def combine_depths(depth, depth_uncertainty):
def combine_depths(self, depth, depth_uncertainty):
"""Combine all the prediced depths with depth uncertainty.
Args:
Expand Down
11 changes: 7 additions & 4 deletions mmdet3d/core/bbox/structures/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,11 @@ def yaw2local(yaw, loc):
torch.Tensor: local yaw (alpha in kitti).
"""
local_yaw = yaw - torch.atan2(loc[:, 0], loc[:, 2])
while local_yaw > np.pi:
local_yaw -= np.pi * 2
while local_yaw < -np.pi:
local_yaw += np.pi * 2
larger_idx = (local_yaw > np.pi).nonzero(as_tuple=False)
small_idx = (local_yaw < -np.pi).nonzero(as_tuple=False)
if len(larger_idx) != 0:
local_yaw[larger_idx] -= 2 * np.pi
if len(small_idx) != 0:
local_yaw[small_idx] += 2 * np.pi

return local_yaw
6 changes: 4 additions & 2 deletions mmdet3d/core/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .array_converter import ArrayConverter, array_converter
from .gaussian import draw_heatmap_gaussian, gaussian_2d, gaussian_radius
from .gaussian import (draw_heatmap_gaussian, ellip_gaussian2D, gaussian_2d,
gaussian_radius, get_ellip_gaussian_2D)

__all__ = [
'gaussian_2d', 'gaussian_radius', 'draw_heatmap_gaussian',
'ArrayConverter', 'array_converter'
'ArrayConverter', 'array_converter', 'ellip_gaussian2D',
'get_ellip_gaussian_2D'
]
72 changes: 72 additions & 0 deletions mmdet3d/core/utils/gaussian.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,75 @@ def gaussian_radius(det_size, min_overlap=0.5):
sq3 = torch.sqrt(b3**2 - 4 * a3 * c3)
r3 = (b3 + sq3) / 2
return min(r1, r2, r3)


def get_ellip_gaussian_2D(heatmap, center, radius_x, radius_y, k=1):
    """Splat an elliptical 2D gaussian onto a heatmap, keeping the maximum.

    The kernel is built with ``ellip_gaussian2D`` using a sigma of
    ``(2 * radius + 1) / 6`` along each axis, cropped so that it does not
    run past the heatmap borders, scaled by ``k`` and merged into the
    heatmap with an element-wise maximum. The heatmap is modified in place.

    Args:
        heatmap (Tensor): Heatmap to update. Its first two dimensions are
            read as (height, width).
        center (list[int]): (x, y) position of the kernel center; both
            entries are truncated with ``int()``.
        radius_x (int): X-axis radius of the gaussian kernel.
        radius_y (int): Y-axis radius of the gaussian kernel.
        k (int, optional): Factor the kernel is multiplied by before the
            maximum is taken. Default: 1.

    Returns:
        out_heatmap (Tensor): The same tensor object as ``heatmap``, after
            the in-place update.
    """
    kernel = ellip_gaussian2D((radius_x, radius_y),
                              sigma_x=(2 * radius_x + 1) / 6,
                              sigma_y=(2 * radius_y + 1) / 6,
                              dtype=heatmap.dtype,
                              device=heatmap.device)

    center_x = int(center[0])
    center_y = int(center[1])
    height, width = heatmap.shape[0:2]

    # Extent of the kernel that fits on each side of the center: limited
    # by the radius and by the distance to the heatmap border.
    reach_left = min(center_x, radius_x)
    reach_right = min(width - center_x, radius_x + 1)
    reach_top = min(center_y, radius_y)
    reach_bottom = min(height - center_y, radius_y + 1)

    # Matching windows in the heatmap and in the kernel.
    map_rows = slice(center_y - reach_top, center_y + reach_bottom)
    map_cols = slice(center_x - reach_left, center_x + reach_right)
    kernel_rows = slice(radius_y - reach_top, radius_y + reach_bottom)
    kernel_cols = slice(radius_x - reach_left, radius_x + reach_right)

    # Write the element-wise maximum straight back into the heatmap window.
    torch.max(
        heatmap[map_rows, map_cols],
        kernel[kernel_rows, kernel_cols] * k,
        out=heatmap[map_rows, map_cols])

    return heatmap


def ellip_gaussian2D(radius,
                     sigma_x,
                     sigma_y,
                     dtype=torch.float32,
                     device='cpu'):
    """Build an axis-aligned elliptical 2D gaussian kernel.

    The kernel is evaluated on the integer grid
    ``[-radius_y, radius_y] x [-radius_x, radius_x]`` and is not
    normalized: the center value is 1. Entries smaller than
    ``eps * max`` (``eps`` being the machine epsilon of the kernel dtype)
    are set to 0.

    Args:
        radius (tuple(int)): Ellipse radius (radius_x, radius_y) of the
            gaussian kernel.
        sigma_x (int | float): X-axis sigma of the gaussian function.
        sigma_y (int | float): Y-axis sigma of the gaussian function.
        dtype (torch.dtype, optional): Dtype of the gaussian tensor.
            Default: torch.float32.
        device (str, optional): Device of the gaussian tensor.
            Default: 'cpu'.

    Returns:
        h (Tensor): Gaussian kernel with a
            ``(2 * radius_y + 1) * (2 * radius_x + 1)`` shape.
    """
    radius_x, radius_y = radius[0], radius[1]

    # Row vector of x offsets and column vector of y offsets; broadcasting
    # them against each other yields the full 2D grid.
    x_offsets = torch.arange(
        -radius_x, radius_x + 1, dtype=dtype, device=device).view(1, -1)
    y_offsets = torch.arange(
        -radius_y, radius_y + 1, dtype=dtype, device=device).view(-1, 1)

    exponent = (-(x_offsets * x_offsets) / (2 * sigma_x * sigma_x) -
                (y_offsets * y_offsets) / (2 * sigma_y * sigma_y))
    kernel = exponent.exp()

    # Zero out entries that are negligible relative to the peak.
    cutoff = torch.finfo(kernel.dtype).eps * kernel.max()
    kernel[kernel < cutoff] = 0

    return kernel
4 changes: 3 additions & 1 deletion mmdet3d/models/dense_heads/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .fcos_mono3d_head import FCOSMono3DHead
from .free_anchor3d_head import FreeAnchor3DHead
from .groupfree3d_head import GroupFree3DHead
from .monoflex_head import MonoFlexHead
from .parta2_rpn_head import PartA2RPNHead
from .pgd_head import PGDHead
from .point_rpn_head import PointRPNHead
Expand All @@ -19,5 +20,6 @@
'Anchor3DHead', 'FreeAnchor3DHead', 'PartA2RPNHead', 'VoteHead',
'SSD3DHead', 'BaseConvBboxHead', 'CenterHead', 'ShapeAwareHead',
'BaseMono3DDenseHead', 'AnchorFreeMono3DHead', 'FCOSMono3DHead',
'GroupFree3DHead', 'PointRPNHead', 'SMOKEMono3DHead', 'PGDHead'
'GroupFree3DHead', 'PointRPNHead', 'SMOKEMono3DHead', 'PGDHead',
'MonoFlexHead'
]
Loading

0 comments on commit 8538177

Please sign in to comment.