From 44702fdb4ba0f80e96bee724766c545d4d93509c Mon Sep 17 00:00:00 2001 From: "Rebecca Chen (Python)" Date: Fri, 2 Aug 2024 10:46:07 -0700 Subject: [PATCH] Add "max" point reduction for chamfer distance Summary: * Adds a "max" option for the point_reduction input to the chamfer_distance function. * When combining the x and y directions, maxes the losses instead of summing them when point_reduction="max". * Moves batch reduction to happen after the directions are combined. * Adds test_chamfer_point_reduction_max and test_single_directional_chamfer_point_reduction_max tests. Fixes https://github.com/facebookresearch/pytorch3d/issues/1838 Reviewed By: bottler Differential Revision: D60614661 fbshipit-source-id: 7879816acfda03e945bada951b931d2c522756eb --- pytorch3d/loss/chamfer.py | 88 ++++++++++++++++++++++++--------------- tests/test_chamfer.py | 83 +++++++++++++++++++++++++++++++++++- 2 files changed, 136 insertions(+), 35 deletions(-) diff --git a/pytorch3d/loss/chamfer.py b/pytorch3d/loss/chamfer.py index 67289020c..7d7876690 100644 --- a/pytorch3d/loss/chamfer.py +++ b/pytorch3d/loss/chamfer.py @@ -27,8 +27,10 @@ def _validate_chamfer_reduction_inputs( """ if batch_reduction is not None and batch_reduction not in ["mean", "sum"]: raise ValueError('batch_reduction must be one of ["mean", "sum"] or None') - if point_reduction is not None and point_reduction not in ["mean", "sum"]: - raise ValueError('point_reduction must be one of ["mean", "sum"] or None') + if point_reduction is not None and point_reduction not in ["mean", "sum", "max"]: + raise ValueError( + 'point_reduction must be one of ["mean", "sum", "max"] or None' + ) if point_reduction is None and batch_reduction is not None: raise ValueError("Batch reduction must be None if point_reduction is None") @@ -80,7 +82,6 @@ def _chamfer_distance_single_direction( x_normals, y_normals, weights, - batch_reduction: Union[str, None], point_reduction: Union[str, None], norm: int, abs_cosine: bool, @@ -103,11 +104,6 @@ def _chamfer_distance_single_direction( raise ValueError("weights cannot be negative.") if weights.sum() == 0.0: weights = weights.view(N, 1) - if batch_reduction in ["mean", "sum"]: - return ( - (x.sum((1, 2)) * weights).sum() * 0.0, - (x.sum((1, 2)) * weights).sum() * 0.0, - ) return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0) cham_norm_x = x.new_zeros(()) @@ -135,7 +131,10 @@ def _chamfer_distance_single_direction( if weights is not None: cham_norm_x *= weights.view(N, 1) - if point_reduction is not None: + if point_reduction == "max": + assert not return_normals + cham_x = cham_x.max(1).values # (N,) + elif point_reduction is not None: # Apply point reduction cham_x = cham_x.sum(1) # (N,) if return_normals: @@ -146,22 +145,34 @@ def _chamfer_distance_single_direction( if return_normals: cham_norm_x /= x_lengths_clamped - if batch_reduction is not None: - # batch_reduction == "sum" - cham_x = cham_x.sum() - if return_normals: - cham_norm_x = cham_norm_x.sum() - if batch_reduction == "mean": - div = weights.sum() if weights is not None else max(N, 1) - cham_x /= div - if return_normals: - cham_norm_x /= div - cham_dist = cham_x cham_normals = cham_norm_x if return_normals else None return cham_dist, cham_normals +def _apply_batch_reduction( + cham_x, cham_norm_x, weights, batch_reduction: Union[str, None] +): + if batch_reduction is None: + return (cham_x, cham_norm_x) + # batch_reduction == "sum" + N = cham_x.shape[0] + cham_x = cham_x.sum() + if cham_norm_x is not None: + cham_norm_x = cham_norm_x.sum() + if batch_reduction == "mean": + if weights is None: + div = max(N, 1) + elif weights.sum() == 0.0: + div = 1 + else: + div = weights.sum() + cham_x /= div + if cham_norm_x is not None: + cham_norm_x /= div + return (cham_x, cham_norm_x) + + def chamfer_distance( x, y, @@ -197,7 +208,8 @@ def chamfer_distance( batch_reduction: Reduction operation to apply for the loss across the batch, can be one of ["mean", "sum"] or None. point_reduction: Reduction operation to apply for the loss across the - points, can be one of ["mean", "sum"] or None. + points, can be one of ["mean", "sum", "max"] or None. Using "max" leads to the + Hausdorff distance. norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2. single_directional: If False (default), loss comes from both the distance between each point in x and its nearest neighbor in y and each point in y and its nearest @@ -227,6 +239,10 @@ def chamfer_distance( if not ((norm == 1) or (norm == 2)): raise ValueError("Support for 1 or 2 norm.") + + if point_reduction == "max" and (x_normals is not None or y_normals is not None): + raise ValueError('Normals must be None if point_reduction is "max"') + x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals) y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals) @@ -238,13 +254,13 @@ def chamfer_distance( x_normals, y_normals, weights, - batch_reduction, point_reduction, norm, abs_cosine, ) if single_directional: - return cham_x, cham_norm_x + loss = cham_x + loss_normals = cham_norm_x else: cham_y, cham_norm_y = _chamfer_distance_single_direction( y, @@ -254,17 +270,23 @@ def chamfer_distance( y_normals, x_normals, weights, - batch_reduction, point_reduction, norm, abs_cosine, ) - if point_reduction is not None: - return ( - cham_x + cham_y, - (cham_norm_x + cham_norm_y) if cham_norm_x is not None else None, - ) - return ( - (cham_x, cham_y), - (cham_norm_x, cham_norm_y) if cham_norm_x is not None else None, - ) + if point_reduction == "max": + loss = torch.maximum(cham_x, cham_y) + loss_normals = None + elif point_reduction is not None: + loss = cham_x + cham_y + if cham_norm_x is not None: + loss_normals = cham_norm_x + cham_norm_y + else: + loss_normals = None + else: + loss = (cham_x, cham_y) + if cham_norm_x is not None: + loss_normals = (cham_norm_x, cham_norm_y) + else: + loss_normals = None + return _apply_batch_reduction(loss, loss_normals, weights, batch_reduction) diff --git a/tests/test_chamfer.py b/tests/test_chamfer.py index 142fbcd87..4adc403f1 100644 --- a/tests/test_chamfer.py +++ b/tests/test_chamfer.py @@ -847,6 +847,85 @@ def test_single_direction_chamfer_point_reduction_none(self): loss, loss_norm, pred_loss[0], pred_loss_norm[0], p1, p11, p2, p22 ) + def test_chamfer_point_reduction_max(self): + """ + Compare output of vectorized chamfer loss with naive implementation + for point_reduction = "max" and batch_reduction = None. + """ + N, P1, P2 = 7, 10, 18 + device = get_random_cuda_device() + points_normals = TestChamfer.init_pointclouds(N, P1, P2, device) + p1 = points_normals.p1 + p2 = points_normals.p2 + weights = points_normals.weights + p11 = p1.detach().clone() + p22 = p2.detach().clone() + p11.requires_grad = True + p22.requires_grad = True + + pred_loss, unused_pred_loss_norm = TestChamfer.chamfer_distance_naive( + p1, p2, x_normals=None, y_normals=None + ) + + loss, loss_norm = chamfer_distance( + p11, + p22, + x_normals=None, + y_normals=None, + weights=weights, + batch_reduction=None, + point_reduction="max", + ) + pred_loss_max = torch.maximum( + pred_loss[0].max(1).values, pred_loss[1].max(1).values + ) + pred_loss_max *= weights + self.assertClose(loss, pred_loss_max) + + self.assertIsNone(loss_norm) + + # Check gradients + self._check_gradients(loss, loss_norm, pred_loss_max, None, p1, p11, p2, p22) + + def test_single_directional_chamfer_point_reduction_max(self): + """ + Compare output of vectorized single directional chamfer loss with naive implementation + for point_reduction = "max" and batch_reduction = None. + """ + N, P1, P2 = 7, 10, 18 + device = get_random_cuda_device() + points_normals = TestChamfer.init_pointclouds(N, P1, P2, device) + p1 = points_normals.p1 + p2 = points_normals.p2 + weights = points_normals.weights + p11 = p1.detach().clone() + p22 = p2.detach().clone() + p11.requires_grad = True + p22.requires_grad = True + + pred_loss, unused_pred_loss_norm = TestChamfer.chamfer_distance_naive( + p1, p2, x_normals=None, y_normals=None + ) + + loss, loss_norm = chamfer_distance( + p11, + p22, + x_normals=None, + y_normals=None, + weights=weights, + batch_reduction=None, + point_reduction="max", + single_directional=True, + ) + pred_loss_max = pred_loss[0].max(1).values + pred_loss_max *= weights + self.assertClose(loss, pred_loss_max) + + self.assertIsNone(loss_norm) + + # Check gradients + self._check_gradients(loss, loss_norm, pred_loss_max, None, p1, p11, p2, p22) + def _check_gradients( self, loss, @@ -1020,9 +1099,9 @@ def test_chamfer_joint_reduction(self): with self.assertRaisesRegex(ValueError, "batch_reduction must be one of"): chamfer_distance(p1, p2, weights=weights, batch_reduction="max") - # Error when point_reduction is not in ["mean", "sum"] or None. + # Error when point_reduction is not in ["mean", "sum", "max"] or None. with self.assertRaisesRegex(ValueError, "point_reduction must be one of"): - chamfer_distance(p1, p2, weights=weights, point_reduction="max") + chamfer_distance(p1, p2, weights=weights, point_reduction="min") def test_incorrect_weights(self): N, P1, P2 = 16, 64, 128