Skip to content

Commit cb7bd33

Browse files
bottlerfacebook-github-bot
authored andcommitted
validate lengths in chamfer and farthest_points
Summary: Fixes #1326 Reviewed By: kjchalup Differential Revision: D39259697 fbshipit-source-id: 51392f4cc4a956165a62901cb115fcefe0e17277
1 parent 6e25fe8 commit cb7bd33

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

pytorch3d/loss/chamfer.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,11 @@ def _handle_pointcloud_input(
4848
if points.ndim != 3:
4949
raise ValueError("Expected points to be of shape (N, P, D)")
5050
X = points
51-
if lengths is not None and (
52-
lengths.ndim != 1 or lengths.shape[0] != X.shape[0]
53-
):
54-
raise ValueError("Expected lengths to be of shape (N,)")
51+
if lengths is not None:
52+
if lengths.ndim != 1 or lengths.shape[0] != X.shape[0]:
53+
raise ValueError("Expected lengths to be of shape (N,)")
54+
if lengths.max() > X.shape[1]:
55+
raise ValueError("A length value was too long")
5556
if lengths is None:
5657
lengths = torch.full(
5758
(X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device

pytorch3d/ops/sample_farthest_points.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,11 @@ def sample_farthest_points(
5656
# Validate inputs
5757
if lengths is None:
5858
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
59-
60-
if lengths.shape != (N,):
61-
raise ValueError("points and lengths must have same batch dimension.")
59+
else:
60+
if lengths.shape != (N,):
61+
raise ValueError("points and lengths must have same batch dimension.")
62+
if lengths.max() > P:
63+
raise ValueError("A value in lengths was too large.")
6264

6365
# TODO: support providing K as a ratio of the total number of points instead of as an int
6466
if isinstance(K, int):
@@ -107,9 +109,11 @@ def sample_farthest_points_naive(
107109
# Validate inputs
108110
if lengths is None:
109111
lengths = torch.full((N,), P, dtype=torch.int64, device=device)
110-
111-
if lengths.shape[0] != N:
112-
raise ValueError("points and lengths must have same batch dimension.")
112+
else:
113+
if lengths.shape != (N,):
114+
raise ValueError("points and lengths must have same batch dimension.")
115+
if lengths.max() > P:
116+
raise ValueError("Invalid lengths.")
113117

114118
# TODO: support providing K as a ratio of the total number of points instead of as an int
115119
if isinstance(K, int):

0 commit comments

Comments
 (0)