Skip to content

Commit

Permalink
Add "max" point reduction for chamfer distance
Browse files Browse the repository at this point in the history
Summary:
* Adds a "max" option for the point_reduction input to the
  chamfer_distance function.
* When combining the x and y directions, maxes the losses instead
  of summing them when point_reduction="max".
* Moves batch reduction to happen after the directions are
  combined.
* Adds test_chamfer_point_reduction_max and
  test_single_directional_chamfer_point_reduction_max tests.

Fixes #1838

Reviewed By: bottler

Differential Revision: D60614661

fbshipit-source-id: 7879816acfda03e945bada951b931d2c522756eb
  • Loading branch information
rchen152 authored and facebook-github-bot committed Aug 2, 2024
1 parent 7edaee7 commit 44702fd
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 35 deletions.
88 changes: 55 additions & 33 deletions pytorch3d/loss/chamfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ def _validate_chamfer_reduction_inputs(
"""
if batch_reduction is not None and batch_reduction not in ["mean", "sum"]:
raise ValueError('batch_reduction must be one of ["mean", "sum"] or None')
if point_reduction is not None and point_reduction not in ["mean", "sum"]:
raise ValueError('point_reduction must be one of ["mean", "sum"] or None')
if point_reduction is not None and point_reduction not in ["mean", "sum", "max"]:
raise ValueError(
'point_reduction must be one of ["mean", "sum", "max"] or None'
)
if point_reduction is None and batch_reduction is not None:
raise ValueError("Batch reduction must be None if point_reduction is None")

Expand Down Expand Up @@ -80,7 +82,6 @@ def _chamfer_distance_single_direction(
x_normals,
y_normals,
weights,
batch_reduction: Union[str, None],
point_reduction: Union[str, None],
norm: int,
abs_cosine: bool,
Expand All @@ -103,11 +104,6 @@ def _chamfer_distance_single_direction(
raise ValueError("weights cannot be negative.")
if weights.sum() == 0.0:
weights = weights.view(N, 1)
if batch_reduction in ["mean", "sum"]:
return (
(x.sum((1, 2)) * weights).sum() * 0.0,
(x.sum((1, 2)) * weights).sum() * 0.0,
)
return ((x.sum((1, 2)) * weights) * 0.0, (x.sum((1, 2)) * weights) * 0.0)

cham_norm_x = x.new_zeros(())
Expand Down Expand Up @@ -135,7 +131,10 @@ def _chamfer_distance_single_direction(
if weights is not None:
cham_norm_x *= weights.view(N, 1)

if point_reduction is not None:
if point_reduction == "max":
assert not return_normals
cham_x = cham_x.max(1).values # (N,)
elif point_reduction is not None:
# Apply point reduction
cham_x = cham_x.sum(1) # (N,)
if return_normals:
Expand All @@ -146,22 +145,34 @@ def _chamfer_distance_single_direction(
if return_normals:
cham_norm_x /= x_lengths_clamped

if batch_reduction is not None:
# batch_reduction == "sum"
cham_x = cham_x.sum()
if return_normals:
cham_norm_x = cham_norm_x.sum()
if batch_reduction == "mean":
div = weights.sum() if weights is not None else max(N, 1)
cham_x /= div
if return_normals:
cham_norm_x /= div

cham_dist = cham_x
cham_normals = cham_norm_x if return_normals else None
return cham_dist, cham_normals


def _apply_batch_reduction(
cham_x, cham_norm_x, weights, batch_reduction: Union[str, None]
):
if batch_reduction is None:
return (cham_x, cham_norm_x)
# batch_reduction == "sum"
N = cham_x.shape[0]
cham_x = cham_x.sum()
if cham_norm_x is not None:
cham_norm_x = cham_norm_x.sum()
if batch_reduction == "mean":
if weights is None:
div = max(N, 1)
elif weights.sum() == 0.0:
div = 1
else:
div = weights.sum()
cham_x /= div
if cham_norm_x is not None:
cham_norm_x /= div
return (cham_x, cham_norm_x)


def chamfer_distance(
x,
y,
Expand Down Expand Up @@ -197,7 +208,8 @@ def chamfer_distance(
batch_reduction: Reduction operation to apply for the loss across the
batch, can be one of ["mean", "sum"] or None.
point_reduction: Reduction operation to apply for the loss across the
points, can be one of ["mean", "sum"] or None.
points, can be one of ["mean", "sum", "max"] or None. Using "max" leads to the
Hausdorff distance.
norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
single_directional: If False (default), loss comes from both the distance between
each point in x and its nearest neighbor in y and each point in y and its nearest
Expand Down Expand Up @@ -227,6 +239,10 @@ def chamfer_distance(

if not ((norm == 1) or (norm == 2)):
raise ValueError("Support for 1 or 2 norm.")

if point_reduction == "max" and (x_normals is not None or y_normals is not None):
raise ValueError('Normals must be None if point_reduction is "max"')

x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)

Expand All @@ -238,13 +254,13 @@ def chamfer_distance(
x_normals,
y_normals,
weights,
batch_reduction,
point_reduction,
norm,
abs_cosine,
)
if single_directional:
return cham_x, cham_norm_x
loss = cham_x
loss_normals = cham_norm_x
else:
cham_y, cham_norm_y = _chamfer_distance_single_direction(
y,
Expand All @@ -254,17 +270,23 @@ def chamfer_distance(
y_normals,
x_normals,
weights,
batch_reduction,
point_reduction,
norm,
abs_cosine,
)
if point_reduction is not None:
return (
cham_x + cham_y,
(cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
)
return (
(cham_x, cham_y),
(cham_norm_x, cham_norm_y) if cham_norm_x is not None else None,
)
if point_reduction == "max":
loss = torch.maximum(cham_x, cham_y)
loss_normals = None
elif point_reduction is not None:
loss = cham_x + cham_y
if cham_norm_x is not None:
loss_normals = cham_norm_x + cham_norm_y
else:
loss_normals = None
else:
loss = (cham_x, cham_y)
if cham_norm_x is not None:
loss_normals = (cham_norm_x, cham_norm_y)
else:
loss_normals = None
return _apply_batch_reduction(loss, loss_normals, weights, batch_reduction)
83 changes: 81 additions & 2 deletions tests/test_chamfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,6 +847,85 @@ def test_single_direction_chamfer_point_reduction_none(self):
loss, loss_norm, pred_loss[0], pred_loss_norm[0], p1, p11, p2, p22
)

def test_chamfer_point_reduction_max(self):
    """
    Check that the vectorized chamfer loss agrees with the naive
    implementation when point_reduction = "max" and batch_reduction = None.
    """
    N, P1, P2 = 7, 10, 18
    device = get_random_cuda_device()
    points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
    p1 = points_normals.p1
    p2 = points_normals.p2
    weights = points_normals.weights

    # Independent leaf copies of the clouds so gradients can be compared.
    p11 = p1.detach().clone()
    p22 = p2.detach().clone()
    p11.requires_grad = True
    p22.requires_grad = True

    pred_loss, _unused_pred_loss_norm = TestChamfer.chamfer_distance_naive(
        p1, p2, x_normals=None, y_normals=None
    )

    loss, loss_norm = chamfer_distance(
        p11,
        p22,
        x_normals=None,
        y_normals=None,
        weights=weights,
        batch_reduction=None,
        point_reduction="max",
    )

    # Hausdorff-style reduction: per-direction max over points, then the
    # max of the two directions, scaled by the per-sample weights.
    expected = torch.maximum(
        pred_loss[0].max(1).values, pred_loss[1].max(1).values
    )
    expected = expected * weights
    self.assertClose(loss, expected)

    self.assertIsNone(loss_norm)

    # Check gradients
    self._check_gradients(loss, loss_norm, expected, None, p1, p11, p2, p22)

def test_single_directional_chamfer_point_reduction_max(self):
    """
    Check that the vectorized single-directional chamfer loss agrees with
    the naive implementation when point_reduction = "max" and
    batch_reduction = None.
    """
    N, P1, P2 = 7, 10, 18
    device = get_random_cuda_device()
    points_normals = TestChamfer.init_pointclouds(N, P1, P2, device)
    p1 = points_normals.p1
    p2 = points_normals.p2
    weights = points_normals.weights

    # Independent leaf copies of the clouds so gradients can be compared.
    p11 = p1.detach().clone()
    p22 = p2.detach().clone()
    p11.requires_grad = True
    p22.requires_grad = True

    pred_loss, _unused_pred_loss_norm = TestChamfer.chamfer_distance_naive(
        p1, p2, x_normals=None, y_normals=None
    )

    loss, loss_norm = chamfer_distance(
        p11,
        p22,
        x_normals=None,
        y_normals=None,
        weights=weights,
        batch_reduction=None,
        point_reduction="max",
        single_directional=True,
    )

    # Only the x -> y direction contributes: per-sample max over points,
    # scaled by the per-sample weights.
    expected = pred_loss[0].max(1).values * weights
    self.assertClose(loss, expected)

    self.assertIsNone(loss_norm)

    # Check gradients
    self._check_gradients(loss, loss_norm, expected, None, p1, p11, p2, p22)

def _check_gradients(
self,
loss,
Expand Down Expand Up @@ -1020,9 +1099,9 @@ def test_chamfer_joint_reduction(self):
with self.assertRaisesRegex(ValueError, "batch_reduction must be one of"):
chamfer_distance(p1, p2, weights=weights, batch_reduction="max")

# Error when point_reduction is not in ["mean", "sum"] or None.
# Error when point_reduction is not in ["mean", "sum", "max"] or None.
with self.assertRaisesRegex(ValueError, "point_reduction must be one of"):
chamfer_distance(p1, p2, weights=weights, point_reduction="max")
chamfer_distance(p1, p2, weights=weights, point_reduction="min")

def test_incorrect_weights(self):
N, P1, P2 = 16, 64, 128
Expand Down

0 comments on commit 44702fd

Please sign in to comment.