
Commit d84f274

haritha-j authored and facebook-github-bot committed
add None option for chamfer distance point reduction (#1605)
Summary: The `chamfer_distance` function currently allows `"sum"` or `"mean"` reduction, but does not support returning unreduced (per-point) loss terms. Unreduced losses are useful when the user wants to inspect individual loss terms or modify them before reduction, for example to apply a robust kernel over the loss.

This PR adds a `None` option to the `point_reduction` parameter, mirroring `batch_reduction`. For the bi-directional chamfer loss, both the forward and backward distances are returned as a tuple of unreduced Tensors, shaped `(N, P1)` and `(N, P2)` as documented below. If normals are provided, the same logic applies to the normals term. This PR addresses issue #622.

Pull Request resolved: #1605
Reviewed By: jcjohnson
Differential Revision: D48313857
Pulled By: bottler
fbshipit-source-id: 35c824827a143649b04166c4817449e1341b7fd9
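A rough usage sketch of the new option (not part of this commit; the input tensors and the Huber-style kernel below are illustrative assumptions): request per-point terms with point_reduction=None, reweight them, then reduce manually.

import torch
from pytorch3d.loss import chamfer_distance

x = torch.rand(8, 100, 3, requires_grad=True)  # hypothetical clouds, (N, P1, 3)
y = torch.rand(8, 120, 3)                      # (N, P2, 3)

# With point_reduction=None and batch_reduction=None, the bi-directional loss is a
# tuple of unreduced per-point terms: forward (N, P1) and backward (N, P2).
(fwd, bwd), _ = chamfer_distance(x, y, point_reduction=None, batch_reduction=None)

def huber(d2, delta=0.1):
    # Robust kernel over squared distances: quadratic near zero, linear in the tail.
    d = d2.clamp(min=1e-12).sqrt()
    return torch.where(d < delta, 0.5 * d2, delta * (d - 0.5 * delta))

loss = huber(fwd).mean() + huber(bwd).mean()
loss.backward()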
1 parent 099fc06 commit d84f274

File tree

2 files changed: +219 −65 lines


pytorch3d/loss/chamfer.py

+42 −28
@@ -13,20 +13,22 @@
 
 
 def _validate_chamfer_reduction_inputs(
-    batch_reduction: Union[str, None], point_reduction: str
+    batch_reduction: Union[str, None], point_reduction: Union[str, None]
 ) -> None:
     """Check the requested reductions are valid.
 
     Args:
         batch_reduction: Reduction operation to apply for the loss across the
             batch, can be one of ["mean", "sum"] or None.
         point_reduction: Reduction operation to apply for the loss across the
-            points, can be one of ["mean", "sum"].
+            points, can be one of ["mean", "sum"] or None.
     """
     if batch_reduction is not None and batch_reduction not in ["mean", "sum"]:
         raise ValueError('batch_reduction must be one of ["mean", "sum"] or None')
-    if point_reduction not in ["mean", "sum"]:
-        raise ValueError('point_reduction must be one of ["mean", "sum"]')
+    if point_reduction is not None and point_reduction not in ["mean", "sum"]:
+        raise ValueError('point_reduction must be one of ["mean", "sum"] or None')
+    if point_reduction is None and batch_reduction is not None:
+        raise ValueError("Batch reduction must be None if point_reduction is None")
 
 
 def _handle_pointcloud_input(
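A minimal sketch of the new validation rule (inputs are hypothetical): point_reduction=None is only accepted together with batch_reduction=None, otherwise the call fails fast.

import torch
from pytorch3d.loss import chamfer_distance

x = torch.rand(2, 128, 3)  # hypothetical clouds, (N, P1, 3)
y = torch.rand(2, 256, 3)  # (N, P2, 3)

try:
    # Invalid: unreduced per-point terms cannot be batch-reduced.
    chamfer_distance(x, y, point_reduction=None, batch_reduction="mean")
except ValueError as err:
    print(err)  # "Batch reduction must be None if point_reduction is None"

# Valid: both reductions disabled.
(fwd, bwd), _ = chamfer_distance(x, y, point_reduction=None, batch_reduction=None)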
@@ -77,7 +79,7 @@ def _chamfer_distance_single_direction(
     y_normals,
     weights,
     batch_reduction: Union[str, None],
-    point_reduction: str,
+    point_reduction: Union[str, None],
     norm: int,
     abs_cosine: bool,
 ):
@@ -130,26 +132,28 @@ def _chamfer_distance_single_direction(
 
         if weights is not None:
             cham_norm_x *= weights.view(N, 1)
-        cham_norm_x = cham_norm_x.sum(1)  # (N,)
 
-    # Apply point reduction
-    cham_x = cham_x.sum(1)  # (N,)
-    if point_reduction == "mean":
-        x_lengths_clamped = x_lengths.clamp(min=1)
-        cham_x /= x_lengths_clamped
+    if point_reduction is not None:
+        # Apply point reduction
+        cham_x = cham_x.sum(1)  # (N,)
         if return_normals:
-            cham_norm_x /= x_lengths_clamped
+            cham_norm_x = cham_norm_x.sum(1)  # (N,)
+        if point_reduction == "mean":
+            x_lengths_clamped = x_lengths.clamp(min=1)
+            cham_x /= x_lengths_clamped
+            if return_normals:
+                cham_norm_x /= x_lengths_clamped
 
-    if batch_reduction is not None:
-        # batch_reduction == "sum"
-        cham_x = cham_x.sum()
-        if return_normals:
-            cham_norm_x = cham_norm_x.sum()
-        if batch_reduction == "mean":
-            div = weights.sum() if weights is not None else max(N, 1)
-            cham_x /= div
+        if batch_reduction is not None:
+            # batch_reduction == "sum"
+            cham_x = cham_x.sum()
             if return_normals:
-                cham_norm_x /= div
+                cham_norm_x = cham_norm_x.sum()
+            if batch_reduction == "mean":
+                div = weights.sum() if weights is not None else max(N, 1)
+                cham_x /= div
+                if return_normals:
+                    cham_norm_x /= div
 
     cham_dist = cham_x
     cham_normals = cham_norm_x if return_normals else None
@@ -165,7 +169,7 @@ def chamfer_distance(
     y_normals=None,
     weights=None,
     batch_reduction: Union[str, None] = "mean",
-    point_reduction: str = "mean",
+    point_reduction: Union[str, None] = "mean",
     norm: int = 2,
     single_directional: bool = False,
     abs_cosine: bool = True,
@@ -191,7 +195,7 @@ def chamfer_distance(
         batch_reduction: Reduction operation to apply for the loss across the
             batch, can be one of ["mean", "sum"] or None.
         point_reduction: Reduction operation to apply for the loss across the
-            points, can be one of ["mean", "sum"].
+            points, can be one of ["mean", "sum"] or None.
         norm: int indicates the norm used for the distance. Supports 1 for L1 and 2 for L2.
         single_directional: If False (default), loss comes from both the distance between
             each point in x and its nearest neighbor in y and each point in y and its nearest
@@ -206,11 +210,16 @@ def chamfer_distance(
         2-element tuple containing
 
         - **loss**: Tensor giving the reduced distance between the pointclouds
-          in x and the pointclouds in y.
+          in x and the pointclouds in y. If point_reduction is None, a 2-element
+          tuple of Tensors containing forward and backward loss terms shaped (N, P1)
+          and (N, P2) (if single_directional is False) or a Tensor containing loss
+          terms shaped (N, P1) (if single_directional is True) is returned.
         - **loss_normals**: Tensor giving the reduced cosine distance of normals
           between pointclouds in x and pointclouds in y. Returns None if
-          x_normals and y_normals are None.
-
+          x_normals and y_normals are None. If point_reduction is None, a 2-element
+          tuple of Tensors containing forward and backward loss terms shaped (N, P1)
+          and (N, P2) (if single_directional is False) or a Tensor containing loss
+          terms shaped (N, P1) (if single_directional is True) is returned.
     """
     _validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
 
@@ -248,7 +257,12 @@ def chamfer_distance(
         norm,
         abs_cosine,
     )
+    if point_reduction is not None:
+        return (
+            cham_x + cham_y,
+            (cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
+        )
     return (
-        cham_x + cham_y,
-        (cham_norm_x + cham_norm_y) if cham_norm_x is not None else None,
+        (cham_x, cham_y),
+        (cham_norm_x, cham_norm_y) if cham_norm_x is not None else None,
     )
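A hedged sketch of the return convention above (clouds are hypothetical): the bi-directional call returns forward and backward terms separately, and summing them per cloud should reproduce the point_reduction="sum" result.

import torch
from pytorch3d.loss import chamfer_distance

N, P1, P2 = 4, 64, 96
x, y = torch.rand(N, P1, 3), torch.rand(N, P2, 3)

(fwd, bwd), _ = chamfer_distance(x, y, point_reduction=None, batch_reduction=None)
assert fwd.shape == (N, P1) and bwd.shape == (N, P2)

# Single-directional: only the forward terms, shape (N, P1).
fwd_only, _ = chamfer_distance(
    x, y, point_reduction=None, batch_reduction=None, single_directional=True
)
assert fwd_only.shape == (N, P1)

# Manual reduction of the unreduced terms should match point_reduction="sum".
manual = fwd.sum(1) + bwd.sum(1)  # (N,)
reduced, _ = chamfer_distance(x, y, point_reduction="sum", batch_reduction=None)
print(torch.allclose(manual, reduced))  # expected True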

tests/test_chamfer.py

+177 −37
@@ -421,9 +421,9 @@ def test_chamfer_pointcloud_object_withnormals(self):
             ("mean", "mean"),
             ("sum", None),
             ("mean", None),
+            (None, None),
         ]
-        for (point_reduction, batch_reduction) in reductions:
-
+        for point_reduction, batch_reduction in reductions:
             # Reinitialize all the tensors so that the
             # backward pass can be computed.
             points_normals = TestChamfer.init_pointclouds(
@@ -450,24 +450,52 @@ def test_chamfer_pointcloud_object_withnormals(self):
                 batch_reduction=batch_reduction,
             )
 
-            self.assertClose(cham_cloud, cham_tensor)
-            self.assertClose(norm_cloud, norm_tensor)
-            self._check_gradients(
-                cham_tensor,
-                norm_tensor,
-                cham_cloud,
-                norm_cloud,
-                points_normals.cloud1.points_list(),
-                points_normals.p1,
-                points_normals.cloud2.points_list(),
-                points_normals.p2,
-                points_normals.cloud1.normals_list(),
-                points_normals.n1,
-                points_normals.cloud2.normals_list(),
-                points_normals.n2,
-                points_normals.p1_lengths,
-                points_normals.p2_lengths,
-            )
+            if point_reduction is None:
+                cham_tensor_bidirectional = torch.hstack(
+                    [cham_tensor[0], cham_tensor[1]]
+                )
+                norm_tensor_bidirectional = torch.hstack(
+                    [norm_tensor[0], norm_tensor[1]]
+                )
+                cham_cloud_bidirectional = torch.hstack([cham_cloud[0], cham_cloud[1]])
+                norm_cloud_bidirectional = torch.hstack([norm_cloud[0], norm_cloud[1]])
+                self.assertClose(cham_cloud_bidirectional, cham_tensor_bidirectional)
+                self.assertClose(norm_cloud_bidirectional, norm_tensor_bidirectional)
+                self._check_gradients(
+                    cham_tensor_bidirectional,
+                    norm_tensor_bidirectional,
+                    cham_cloud_bidirectional,
+                    norm_cloud_bidirectional,
+                    points_normals.cloud1.points_list(),
+                    points_normals.p1,
+                    points_normals.cloud2.points_list(),
+                    points_normals.p2,
+                    points_normals.cloud1.normals_list(),
+                    points_normals.n1,
+                    points_normals.cloud2.normals_list(),
+                    points_normals.n2,
+                    points_normals.p1_lengths,
+                    points_normals.p2_lengths,
+                )
+            else:
+                self.assertClose(cham_cloud, cham_tensor)
+                self.assertClose(norm_cloud, norm_tensor)
+                self._check_gradients(
+                    cham_tensor,
+                    norm_tensor,
+                    cham_cloud,
+                    norm_cloud,
+                    points_normals.cloud1.points_list(),
+                    points_normals.p1,
+                    points_normals.cloud2.points_list(),
+                    points_normals.p2,
+                    points_normals.cloud1.normals_list(),
+                    points_normals.n1,
+                    points_normals.cloud2.normals_list(),
+                    points_normals.n2,
+                    points_normals.p1_lengths,
+                    points_normals.p2_lengths,
+                )
 
     def test_chamfer_pointcloud_object_nonormals(self):
         N = 5
@@ -481,9 +509,9 @@ def test_chamfer_pointcloud_object_nonormals(self):
             ("mean", "mean"),
             ("sum", None),
             ("mean", None),
+            (None, None),
         ]
-        for (point_reduction, batch_reduction) in reductions:
-
+        for point_reduction, batch_reduction in reductions:
            # Reinitialize all the tensors so that the
            # backward pass can be computed.
            points_normals = TestChamfer.init_pointclouds(
@@ -508,19 +536,38 @@ def test_chamfer_pointcloud_object_nonormals(self):
                 batch_reduction=batch_reduction,
             )
 
-            self.assertClose(cham_cloud, cham_tensor)
-            self._check_gradients(
-                cham_tensor,
-                None,
-                cham_cloud,
-                None,
-                points_normals.cloud1.points_list(),
-                points_normals.p1,
-                points_normals.cloud2.points_list(),
-                points_normals.p2,
-                lengths1=points_normals.p1_lengths,
-                lengths2=points_normals.p2_lengths,
-            )
+            if point_reduction is None:
+                cham_tensor_bidirectional = torch.hstack(
+                    [cham_tensor[0], cham_tensor[1]]
+                )
+                cham_cloud_bidirectional = torch.hstack([cham_cloud[0], cham_cloud[1]])
+                self.assertClose(cham_cloud_bidirectional, cham_tensor_bidirectional)
+                self._check_gradients(
+                    cham_tensor_bidirectional,
+                    None,
+                    cham_cloud_bidirectional,
+                    None,
+                    points_normals.cloud1.points_list(),
+                    points_normals.p1,
+                    points_normals.cloud2.points_list(),
+                    points_normals.p2,
+                    lengths1=points_normals.p1_lengths,
+                    lengths2=points_normals.p2_lengths,
+                )
+            else:
+                self.assertClose(cham_cloud, cham_tensor)
+                self._check_gradients(
+                    cham_tensor,
+                    None,
+                    cham_cloud,
+                    None,
+                    points_normals.cloud1.points_list(),
+                    points_normals.p1,
+                    points_normals.cloud2.points_list(),
+                    points_normals.p2,
+                    lengths1=points_normals.p1_lengths,
+                    lengths2=points_normals.p2_lengths,
+                )
 
     def test_chamfer_point_reduction_mean(self):
         """
@@ -707,6 +754,99 @@ def test_single_directional_chamfer_point_reduction_sum(self):
             loss, loss_norm, pred_loss_sum, pred_loss_norm_sum, p1, p11, p2, p22
         )
 
+    def test_chamfer_point_reduction_none(self):
+        """
+        Compare output of vectorized chamfer loss with naive implementation
+        for point_reduction = None and batch_reduction = None.
+        """
+        N, max_P1, max_P2 = 7, 10, 18
+        device = get_random_cuda_device()
+        points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
+        p1 = points_normals.p1
+        p2 = points_normals.p2
+        p1_normals = points_normals.n1
+        p2_normals = points_normals.n2
+        p11 = p1.detach().clone()
+        p22 = p2.detach().clone()
+        p11.requires_grad = True
+        p22.requires_grad = True
+
+        pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
+            p1, p2, x_normals=p1_normals, y_normals=p2_normals
+        )
+
+        # point_reduction = None
+        loss, loss_norm = chamfer_distance(
+            p11,
+            p22,
+            x_normals=p1_normals,
+            y_normals=p2_normals,
+            batch_reduction=None,
+            point_reduction=None,
+        )
+
+        loss_bidirectional = torch.hstack([loss[0], loss[1]])
+        pred_loss_bidirectional = torch.hstack([pred_loss[0], pred_loss[1]])
+        loss_norm_bidirectional = torch.hstack([loss_norm[0], loss_norm[1]])
+        pred_loss_norm_bidirectional = torch.hstack(
+            [pred_loss_norm[0], pred_loss_norm[1]]
+        )
+
+        self.assertClose(loss_bidirectional, pred_loss_bidirectional)
+        self.assertClose(loss_norm_bidirectional, pred_loss_norm_bidirectional)
+
+        # Check gradients
+        self._check_gradients(
+            loss_bidirectional,
+            loss_norm_bidirectional,
+            pred_loss_bidirectional,
+            pred_loss_norm_bidirectional,
+            p1,
+            p11,
+            p2,
+            p22,
+        )
+
+    def test_single_direction_chamfer_point_reduction_none(self):
+        """
+        Compare output of vectorized chamfer loss with naive implementation
+        for point_reduction = None and batch_reduction = None.
+        """
+        N, max_P1, max_P2 = 7, 10, 18
+        device = get_random_cuda_device()
+        points_normals = TestChamfer.init_pointclouds(N, max_P1, max_P2, device)
+        p1 = points_normals.p1
+        p2 = points_normals.p2
+        p1_normals = points_normals.n1
+        p2_normals = points_normals.n2
+        p11 = p1.detach().clone()
+        p22 = p2.detach().clone()
+        p11.requires_grad = True
+        p22.requires_grad = True
+
+        pred_loss, pred_loss_norm = TestChamfer.chamfer_distance_naive(
+            p1, p2, x_normals=p1_normals, y_normals=p2_normals
+        )
+
+        # point_reduction = None
+        loss, loss_norm = chamfer_distance(
+            p11,
+            p22,
+            x_normals=p1_normals,
+            y_normals=p2_normals,
+            batch_reduction=None,
+            point_reduction=None,
+            single_directional=True,
+        )
+
+        self.assertClose(loss, pred_loss[0])
+        self.assertClose(loss_norm, pred_loss_norm[0])
+
+        # Check gradients
+        self._check_gradients(
+            loss, loss_norm, pred_loss[0], pred_loss_norm[0], p1, p11, p2, p22
+        )
+
     def _check_gradients(
         self,
         loss,
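Outside the test harness, the property these tests exercise against the naive implementation can be sketched directly (a hedged example with hypothetical inputs, default norm=2, no weights): the unreduced terms are the squared nearest-neighbor distances from knn_points.

import torch
from pytorch3d.loss import chamfer_distance
from pytorch3d.ops import knn_points

x, y = torch.rand(3, 50, 3), torch.rand(3, 80, 3)

(fwd, bwd), _ = chamfer_distance(x, y, point_reduction=None, batch_reduction=None)
nn_xy = knn_points(x, y, K=1).dists[..., 0]  # (3, 50) squared NN distances, x -> y
nn_yx = knn_points(y, x, K=1).dists[..., 0]  # (3, 80) squared NN distances, y -> x
print(torch.allclose(fwd, nn_xy), torch.allclose(bwd, nn_yx))  # expected True, True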
@@ -880,9 +1020,9 @@ def test_chamfer_joint_reduction(self):
         with self.assertRaisesRegex(ValueError, "batch_reduction must be one of"):
             chamfer_distance(p1, p2, weights=weights, batch_reduction="max")
 
-        # Error when point_reduction is not in ["mean", "sum"].
+        # Error when point_reduction is not in ["mean", "sum"] or None.
         with self.assertRaisesRegex(ValueError, "point_reduction must be one of"):
-            chamfer_distance(p1, p2, weights=weights, point_reduction=None)
+            chamfer_distance(p1, p2, weights=weights, point_reduction="max")
 
     def test_incorrect_weights(self):
         N, P1, P2 = 16, 64, 128
