Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support Uncertain L1 Loss #950

Merged
merged 2 commits into from
Sep 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion mmdet3d/models/losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
from .axis_aligned_iou_loss import AxisAlignedIoULoss, axis_aligned_iou_loss
from .chamfer_distance import ChamferDistance, chamfer_distance
from .paconv_regularization_loss import PAConvRegularizationLoss
from .uncertain_smooth_l1_loss import UncertainL1Loss, UncertainSmoothL1Loss

__all__ = [
'FocalLoss', 'SmoothL1Loss', 'binary_cross_entropy', 'ChamferDistance',
'chamfer_distance', 'axis_aligned_iou_loss', 'AxisAlignedIoULoss',
'PAConvRegularizationLoss'
'PAConvRegularizationLoss', 'UncertainL1Loss', 'UncertainSmoothL1Loss'
]
175 changes: 175 additions & 0 deletions mmdet3d/models/losses/uncertain_smooth_l1_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
import torch
from torch import nn as nn

from mmdet.models.builder import LOSSES
from mmdet.models.losses.utils import weighted_loss


@weighted_loss
def uncertain_smooth_l1_loss(pred, target, sigma, alpha=1.0, beta=1.0):
"""Smooth L1 loss with uncertainty.

Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction.
sigma (torch.Tensor): The sigma for uncertainty.
alpha (float, optional): The coefficient of log(sigma).
Defaults to 1.0.
beta (float, optional): The threshold in the piecewise function.
Defaults to 1.0.

Returns:
torch.Tensor: Calculated loss
"""
assert beta > 0
assert target.numel() > 0
assert pred.size() == target.size() == sigma.size(), 'The size of pred ' \
f'{pred.size()}, target {target.size()}, and sigma {sigma.size()} ' \
'are inconsistent.'
diff = torch.abs(pred - target)
loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
diff - 0.5 * beta)
loss = torch.exp(-sigma) * loss + alpha * sigma

return loss


@weighted_loss
def uncertain_l1_loss(pred, target, sigma, alpha=1.0):
"""L1 loss with uncertainty.

Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction.
sigma (torch.Tensor): The sigma for uncertainty.
alpha (float, optional): The coefficient of log(sigma).
Defaults to 1.0.

Returns:
torch.Tensor: Calculated loss
"""
assert target.numel() > 0
assert pred.size() == target.size() == sigma.size(), 'The size of pred ' \
f'{pred.size()}, target {target.size()}, and sigma {sigma.size()} ' \
'are inconsistent.'
loss = torch.abs(pred - target)
loss = torch.exp(-sigma) * loss + alpha * sigma
return loss


@LOSSES.register_module()
class UncertainSmoothL1Loss(nn.Module):
r"""Smooth L1 loss with uncertainty.

Please refer to `PGD <https://arxiv.org/abs/2107.14160>`_ and
`Multi-Task Learning Using Uncertainty to Weigh Losses for Scene Geometry
and Semantics <https://arxiv.org/abs/1705.07115>`_ for more details.

Args:
alpha (float, optional): The coefficient of log(sigma).
Defaults to 1.0.
beta (float, optional): The threshold in the piecewise function.
Defaults to 1.0.
reduction (str, optional): The method to reduce the loss.
Options are 'none', 'mean' and 'sum'. Defaults to 'mean'.
loss_weight (float, optional): The weight of loss. Defaults to 1.0
"""

def __init__(self, alpha=1.0, beta=1.0, reduction='mean', loss_weight=1.0):
super(UncertainSmoothL1Loss, self).__init__()
assert reduction in ['none', 'sum', 'mean']
self.alpha = alpha
self.beta = beta
self.reduction = reduction
self.loss_weight = loss_weight

def forward(self,
pred,
target,
sigma,
weight=None,
avg_factor=None,
reduction_override=None,
**kwargs):
"""Forward function.

Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction.
sigma (torch.Tensor): The sigma for uncertainty.
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss_bbox = self.loss_weight * uncertain_smooth_l1_loss(
pred,
target,
weight,
sigma=sigma,
alpha=self.alpha,
beta=self.beta,
reduction=reduction,
avg_factor=avg_factor,
**kwargs)
return loss_bbox


@LOSSES.register_module()
class UncertainL1Loss(nn.Module):
"""L1 loss with uncertainty.

Args:
alpha (float, optional): The coefficient of log(sigma).
Defaults to 1.0.
reduction (str, optional): The method to reduce the loss.
Options are 'none', 'mean' and 'sum'. Defaults to 'mean'.
loss_weight (float, optional): The weight of loss. Defaults to 1.0.
"""

def __init__(self, alpha=1.0, reduction='mean', loss_weight=1.0):
super(UncertainL1Loss, self).__init__()
assert reduction in ['none', 'sum', 'mean']
self.alpha = alpha
self.reduction = reduction
self.loss_weight = loss_weight

def forward(self,
pred,
target,
sigma,
weight=None,
avg_factor=None,
reduction_override=None):
"""Forward function.

Args:
pred (torch.Tensor): The prediction.
target (torch.Tensor): The learning target of the prediction.
sigma (torch.Tensor): The sigma for uncertainty.
weight (torch.Tensor, optional): The weight of loss for each
prediction. Defaults to None.
avg_factor (int, optional): Average factor that is used to average
the loss. Defaults to None.
reduction_override (str, optional): The reduction method used to
override the original reduction method of the loss.
Defaults to None.
"""
assert reduction_override in (None, 'none', 'mean', 'sum')
reduction = (
reduction_override if reduction_override else self.reduction)
loss_bbox = self.loss_weight * uncertain_l1_loss(
pred,
target,
weight,
sigma=sigma,
alpha=self.alpha,
reduction=reduction,
avg_factor=avg_factor)
return loss_bbox
37 changes: 37 additions & 0 deletions tests/test_metrics/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch
from torch import nn as nn

from mmdet.models import build_loss


def test_chamfer_disrance():
from mmdet3d.models.losses import ChamferDistance, chamfer_distance
Expand Down Expand Up @@ -109,3 +111,38 @@ def __init__(self):
model.modules(), reduction_override='none')
assert none_corr_loss.shape[0] == 3
assert torch.allclose(none_corr_loss.mean(), mean_corr_loss)


def test_uncertain_smooth_l1_loss():
from mmdet3d.models.losses import UncertainL1Loss, UncertainSmoothL1Loss

# reduction shoule be in ['none', 'mean', 'sum']
with pytest.raises(AssertionError):
uncertain_l1_loss = UncertainL1Loss(reduction='l2')
with pytest.raises(AssertionError):
uncertain_smooth_l1_loss = UncertainSmoothL1Loss(reduction='l2')

pred = torch.tensor([1.5783, 0.5972, 1.4821, 0.9488])
target = torch.tensor([1.0813, -0.3466, -1.1404, -0.9665])
sigma = torch.tensor([-1.0053, 0.4710, -1.7784, -0.8603])

# test uncertain l1 loss
uncertain_l1_loss_cfg = dict(
type='UncertainL1Loss', alpha=1.0, reduction='mean', loss_weight=1.0)
uncertain_l1_loss = build_loss(uncertain_l1_loss_cfg)
mean_l1_loss = uncertain_l1_loss(pred, target, sigma)
expected_l1_loss = torch.tensor(4.7069)
assert torch.allclose(mean_l1_loss, expected_l1_loss, atol=1e-4)

# test uncertain smooth l1 loss
uncertain_smooth_l1_loss_cfg = dict(
type='UncertainSmoothL1Loss',
alpha=1.0,
beta=0.5,
reduction='mean',
loss_weight=1.0)
uncertain_smooth_l1_loss = build_loss(uncertain_smooth_l1_loss_cfg)
mean_smooth_l1_loss = uncertain_smooth_l1_loss(pred, target, sigma)
expected_smooth_l1_loss = torch.tensor(3.9795)
assert torch.allclose(
mean_smooth_l1_loss, expected_smooth_l1_loss, atol=1e-4)