Skip to content

Commit 8abbe22

Browse files
davnov134 authored and facebook-github-bot committed
ICP - point-to-point version
Summary: The iterative closest point algorithm - point-to-point version.

Output of `bm_iterative_closest_point`:

Argument key: `batch_size dim n_points_X n_points_Y use_pointclouds`

```
Benchmark                                    Avg Time(μs)   Peak Time(μs)   Iterations
--------------------------------------------------------------------------------
IterativeClosestPoint_1_3_100_100_False            107569          111323            5
IterativeClosestPoint_1_3_100_1000_False           118972          122306            5
IterativeClosestPoint_1_3_1000_100_False           108576          110978            5
IterativeClosestPoint_1_3_1000_1000_False          331836          333515            2
IterativeClosestPoint_1_20_100_100_False           134387          137842            4
IterativeClosestPoint_1_20_100_1000_False          149218          153405            4
IterativeClosestPoint_1_20_1000_100_False          414248          416595            2
IterativeClosestPoint_1_20_1000_1000_False         374318          374662            2
IterativeClosestPoint_10_3_100_100_False           539852          539852            1
IterativeClosestPoint_10_3_100_1000_False          752784          752784            1
IterativeClosestPoint_10_3_1000_100_False         1070700         1070700            1
IterativeClosestPoint_10_3_1000_1000_False        1164020         1164020            1
IterativeClosestPoint_10_20_100_100_False          374548          377337            2
IterativeClosestPoint_10_20_100_1000_False         472764          476685            2
IterativeClosestPoint_10_20_1000_100_False        1457175         1457175            1
IterativeClosestPoint_10_20_1000_1000_False       2195820         2195820            1
IterativeClosestPoint_1_3_100_100_True             110084          115824            5
IterativeClosestPoint_1_3_100_1000_True            142728          147696            4
IterativeClosestPoint_1_3_1000_100_True            212966          213966            3
IterativeClosestPoint_1_3_1000_1000_True           369130          375114            2
IterativeClosestPoint_10_3_100_100_True            354615          355179            2
IterativeClosestPoint_10_3_100_1000_True           451815          452704            2
IterativeClosestPoint_10_3_1000_100_True           511833          511833            1
IterativeClosestPoint_10_3_1000_1000_True          798453          798453            1
--------------------------------------------------------------------------------
```

Reviewed By: shapovalov, gkioxari

Differential Revision: D19909952

fbshipit-source-id: f77fadc88fb7c53999909d594114b182ee2a3def
1 parent b5eb33b commit 8abbe22

File tree

6 files changed

+603
-45
lines changed

6 files changed

+603
-45
lines changed

pytorch3d/ops/__init__.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,10 @@
66
from .knn import knn_gather, knn_points
77
from .mesh_face_areas_normals import mesh_face_areas_normals
88
from .packed_to_padded import packed_to_padded, padded_to_packed
9-
from .points_alignment import corresponding_points_alignment
9+
from .points_alignment import corresponding_points_alignment, iterative_closest_point
1010
from .sample_points_from_meshes import sample_points_from_meshes
1111
from .subdivide_meshes import SubdivideMeshes
12+
from .utils import convert_pointclouds_to_tensor, eyes, is_pointclouds, wmean
1213
from .vert_align import vert_align
1314

1415

pytorch3d/ops/points_alignment.py

+249-38
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,231 @@
11
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
22

33
import warnings
4-
from typing import List, Tuple, Union
4+
from typing import TYPE_CHECKING, List, NamedTuple, Optional, Union
55

66
import torch
7-
from pytorch3d.ops import utils as oputil
7+
from pytorch3d.ops import knn_points
88
from pytorch3d.structures import utils as strutil
9-
from pytorch3d.structures.pointclouds import Pointclouds
9+
10+
from . import utils as oputil
11+
12+
13+
if TYPE_CHECKING:
14+
from pytorch3d.structures.pointclouds import Pointclouds
15+
16+
17+
# named tuples for inputs/outputs
18+
class SimilarityTransform(NamedTuple):
    """
    A batch of similarity transformations (rotation, translation, scale).

    Used both as the optional initial transform passed to
    `iterative_closest_point` and as the result type of the alignment
    routines in this module.
    """

    # batch of orthonormal matrices of shape `(minibatch, d, d)`
    R: torch.Tensor
    # batch of translations of shape `(minibatch, d)`
    T: torch.Tensor
    # batch of scaling factors of shape `(minibatch,)`
    s: torch.Tensor
22+
23+
24+
class ICPSolution(NamedTuple):
    """
    Result bundle returned by `iterative_closest_point`.
    """

    # True if the relative-rmse stopping criterion was met before the
    # iteration budget ran out, False otherwise
    converged: bool
    # root mean squared error attained at termination; stays `None` when
    # no ICP iteration was executed
    rmse: Optional[torch.Tensor]
    # the input cloud `X` after applying the final transformation
    Xt: torch.Tensor
    # the final similarity transformation (R, T, s)
    RTs: SimilarityTransform
    # the transformation estimated after each ICP iteration, in order
    t_history: List[SimilarityTransform]
30+
31+
32+
def iterative_closest_point(
    X: Union[torch.Tensor, "Pointclouds"],
    Y: Union[torch.Tensor, "Pointclouds"],
    init_transform: Optional[SimilarityTransform] = None,
    max_iterations: int = 100,
    relative_rmse_thr: float = 1e-6,
    estimate_scale: bool = False,
    allow_reflection: bool = False,
    verbose: bool = False,
) -> ICPSolution:
    """
    Executes the iterative closest point (ICP) algorithm [1, 2] in order to find
    a similarity transformation (rotation `R`, translation `T`, and
    optionally scale `s`) between two given differently-sized sets of
    `d`-dimensional points `X` and `Y`, such that:

    `s[i] X[i] R[i] + T[i] = Y[NN[i]]`,

    for all batch indices `i` in the least squares sense. Here, `Y[NN[i]]`
    stands for the points of `Y[i]` that are the nearest neighbors of each
    (transformed) point in `X[i]`.
    Note, however, that the solution is only a local optimum.

    Args:
        **X**: Batch of `d`-dimensional points
            of shape `(minibatch, num_points_X, d)` or a `Pointclouds` object.
        **Y**: Batch of `d`-dimensional points
            of shape `(minibatch, num_points_Y, d)` or a `Pointclouds` object.
        **init_transform**: A named-tuple `SimilarityTransform` of tensors
            `R`, `T`, `s`, where `R` is a batch of orthonormal matrices of
            shape `(minibatch, d, d)`, `T` is a batch of translations
            of shape `(minibatch, d)` and `s` is a batch of scaling factors
            of shape `(minibatch,)`.
        **max_iterations**: The maximum number of ICP iterations.
        **relative_rmse_thr**: A threshold on the relative root mean squared error
            used to terminate the algorithm.
        **estimate_scale**: If `True`, also estimates a scaling component `s`
            of the transformation. Otherwise assumes the identity
            scale and returns a tensor of ones.
        **allow_reflection**: If `True`, allows the algorithm to return `R`
            which is orthonormal but has determinant==-1.
        **verbose**: If `True`, prints status messages during each ICP iteration.

    Returns:
        A named tuple `ICPSolution` with the following fields:
        **converged**: A boolean flag denoting whether the algorithm converged
            successfully (=`True`) or not (=`False`).
        **rmse**: Attained root mean squared error after termination of ICP
            (`None` if no iteration was run, i.e. `max_iterations <= 0`).
        **Xt**: The point cloud `X` transformed with the final transformation
            (`R`, `T`, `s`). If `X` is a `Pointclouds` object, returns an
            instance of `Pointclouds`, otherwise returns `torch.Tensor`.
        **RTs**: A named tuple `SimilarityTransform` containing
        a batch of similarity transforms with fields:
            **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
            **T**: Batch of translations of shape `(minibatch, d)`.
            **s**: batch of scaling factors of shape `(minibatch, )`.
        **t_history**: A list of named tuples `SimilarityTransform`
            the transformation parameters after each ICP iteration.

    Raises:
        ValueError: If `X` and `Y` differ in batch size or point dimension,
            or if `init_transform` is not a valid `(R, T, s)` triple of the
            shapes described above.

    References:
        [1] Besl & McKay: A Method for Registration of 3-D Shapes. TPAMI, 1992.
        [2] https://en.wikipedia.org/wiki/Iterative_closest_point
    """

    # make sure we convert input Pointclouds structures to
    # padded tensors of shape (N, P, d)
    Xt, num_points_X = oputil.convert_pointclouds_to_tensor(X)
    Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)

    b, size_X, dim = Xt.shape

    if (Xt.shape[2] != Yt.shape[2]) or (Xt.shape[0] != Yt.shape[0]):
        raise ValueError(
            "Point sets X and Y have to have the same "
            + "number of batches and data dimensions."
        )

    # Per-point weights for X: 1 for real points, 0 for padding.
    # The mask depends only on the padding of X; whether Y is padded, or
    # whether X and Y happen to hold the same number of points, is irrelevant.
    # Without it, padded rows of X would be treated as genuine points in both
    # the alignment and the rmse.
    if (num_points_X < size_X).any():
        # we have a heterogeneous input (e.g. because X is
        # an instance of Pointclouds)
        mask_X = (
            torch.arange(size_X, dtype=torch.int64, device=Xt.device)[None]
            < num_points_X[:, None]
        ).type_as(Xt)
    else:
        mask_X = Xt.new_ones(b, size_X)

    # clone the initial point cloud
    Xt_init = Xt.clone()

    if init_transform is not None:
        # parse the initial transform from the input and apply to Xt;
        # the shape check is an explicit test rather than an `assert` so that
        # it is still enforced when Python runs with optimizations (-O)
        try:
            R, T, s = init_transform
            init_transform_ok = (
                R.shape == torch.Size((b, dim, dim))
                and T.shape == torch.Size((b, dim))
                and s.shape == torch.Size((b,))
            )
        except (TypeError, ValueError, AttributeError):
            # not a 3-element iterable, or its elements are not tensors
            init_transform_ok = False
        if not init_transform_ok:
            raise ValueError(
                "The initial transformation init_transform has to be "
                "a named tuple SimilarityTransform with elements (R, T, s). "
                "R are dim x dim orthonormal matrices of shape "
                "(minibatch, dim, dim), T is a batch of dim-dimensional "
                "translations of shape (minibatch, dim) and s is a batch "
                "of scalars of shape (minibatch,)."
            )
        # apply the init transform to the input point cloud
        Xt = _apply_similarity_transform(Xt, R, T, s)
    else:
        # initialize the transformation with identity
        R = oputil.eyes(dim, b, device=Xt.device, dtype=Xt.dtype)
        T = Xt.new_zeros((b, dim))
        s = Xt.new_ones(b)

    prev_rmse = None
    rmse = None
    iteration = -1
    converged = False

    # initialize the transformation history
    t_history = []

    # the main loop over ICP iterations
    for iteration in range(max_iterations):
        # nearest neighbor in Yt of every (currently transformed) point of Xt
        Xt_nn_points = knn_points(
            Xt, Yt, lengths1=num_points_X, lengths2=num_points_Y, K=1, return_nn=True
        )[2][:, :, 0, :]

        # get the alignment of the nearest neighbors from Yt with Xt_init;
        # aligning the *initial* cloud means (R, T, s) is always the total
        # transformation rather than an incremental update
        R, T, s = corresponding_points_alignment(
            Xt_init,
            Xt_nn_points,
            weights=mask_X,
            estimate_scale=estimate_scale,
            allow_reflection=allow_reflection,
        )

        # apply the estimated similarity transform to Xt_init
        Xt = _apply_similarity_transform(Xt_init, R, T, s)

        # add the current transformation to the history
        t_history.append(SimilarityTransform(R, T, s))

        # compute the root mean squared error
        Xt_sq_diff = ((Xt - Xt_nn_points) ** 2).sum(2)
        rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask_X).sqrt()[:, 0, 0]

        # compute the relative rmse
        if prev_rmse is None:
            relative_rmse = rmse.new_ones(b)
        else:
            # clamp the denominator so that an exactly-zero previous rmse
            # yields 0 instead of 0/0 = NaN; NaN never compares <= threshold
            # and would block convergence for the whole batch
            relative_rmse = (prev_rmse - rmse) / prev_rmse.clamp(
                min=torch.finfo(prev_rmse.dtype).tiny
            )

        if verbose:
            rmse_msg = (
                f"ICP iteration {iteration}: mean/max rmse = "
                + f"{rmse.mean():1.2e}/{rmse.max():1.2e} "
                + f"; mean relative rmse = {relative_rmse.mean():1.2e}"
            )
            print(rmse_msg)

        # check for convergence: every batch element has to satisfy the threshold
        if (relative_rmse <= relative_rmse_thr).all():
            converged = True
            break

        # update the previous rmse
        prev_rmse = rmse

    if verbose:
        if converged:
            print(f"ICP has converged in {iteration + 1} iterations.")
        else:
            print(f"ICP has not converged in {max_iterations} iterations.")

    if oputil.is_pointclouds(X):
        # return the same container type as the input
        Xt = X.update_padded(Xt)  # type: ignore

    return ICPSolution(converged, rmse, Xt, SimilarityTransform(R, T, s), t_history)
214+
215+
216+
# threshold for checking that point cross-correlation
217+
# is full rank in corresponding_points_alignment
218+
AMBIGUOUS_ROT_SINGULAR_THR = 1e-15
10219

11220

12221
def corresponding_points_alignment(
13-
X: Union[torch.Tensor, Pointclouds],
14-
Y: Union[torch.Tensor, Pointclouds],
222+
X: Union[torch.Tensor, "Pointclouds"],
223+
Y: Union[torch.Tensor, "Pointclouds"],
15224
weights: Union[torch.Tensor, List[torch.Tensor], None] = None,
16225
estimate_scale: bool = False,
17226
allow_reflection: bool = False,
18-
eps: float = 1e-8,
19-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
227+
eps: float = 1e-9,
228+
) -> SimilarityTransform:
20229
"""
21230
Finds a similarity transformation (rotation `R`, translation `T`
22231
and optionally scale `s`) between two given sets of corresponding
@@ -29,25 +238,25 @@ def corresponding_points_alignment(
29238
The algorithm is also known as Umeyama [1].
30239
31240
Args:
32-
X: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
241+
**X**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
33242
or a `Pointclouds` object.
34-
Y: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
243+
**Y**: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
35244
or a `Pointclouds` object.
36-
weights: Batch of non-negative weights of
245+
**weights**: Batch of non-negative weights of
37246
shape `(minibatch, num_point)` or list of `minibatch` 1-dimensional
38247
tensors that may have different shapes; in that case, the length of
39248
i-th tensor should be equal to the number of points in X_i and Y_i.
40249
Passing `None` means uniform weights.
41-
estimate_scale: If `True`, also estimates a scaling component `s`
250+
**estimate_scale**: If `True`, also estimates a scaling component `s`
42251
of the transformation. Otherwise assumes an identity
43252
scale and returns a tensor of ones.
44-
allow_reflection: If `True`, allows the algorithm to return `R`
253+
**allow_reflection**: If `True`, allows the algorithm to return `R`
45254
which is orthonormal but has determinant==-1.
46-
eps: A scalar for clamping to avoid dividing by zero. Active for the
255+
**eps**: A scalar for clamping to avoid dividing by zero. Active for the
47256
code that estimates the output scale `s`.
48257
49258
Returns:
50-
3-element tuple containing
259+
3-element named tuple `SimilarityTransform` containing
51260
- **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
52261
- **T**: Batch of translations of shape `(minibatch, d)`.
53262
- **s**: batch of scaling factors of shape `(minibatch, )`.
@@ -58,8 +267,8 @@ def corresponding_points_alignment(
58267
"""
59268

60269
# make sure we convert input Pointclouds structures to tensors
61-
Xt, num_points = _convert_point_cloud_to_tensor(X)
62-
Yt, num_points_Y = _convert_point_cloud_to_tensor(Y)
270+
Xt, num_points = oputil.convert_pointclouds_to_tensor(X)
271+
Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)
63272

64273
if (Xt.shape != Yt.shape) or (num_points != num_points_Y).any():
65274
raise ValueError(
@@ -90,8 +299,8 @@ def corresponding_points_alignment(
90299
weights = mask if weights is None else mask * weights.type_as(Xt)
91300

92301
# compute the centroids of the point sets
93-
Xmu = oputil.wmean(Xt, weights, eps=eps)
94-
Ymu = oputil.wmean(Yt, weights, eps=eps)
302+
Xmu = oputil.wmean(Xt, weight=weights, eps=eps)
303+
Ymu = oputil.wmean(Yt, weight=weights, eps=eps)
95304

96305
# mean-center the point sets
97306
Xc = Xt - Xmu
@@ -107,7 +316,7 @@ def corresponding_points_alignment(
107316
if (num_points < (dim + 1)).any():
108317
warnings.warn(
109318
"The size of one of the point clouds is <= dim+1. "
110-
+ "corresponding_points_alignment can't return a unique solution."
319+
+ "corresponding_points_alignment cannot return a unique rotation."
111320
)
112321

113322
# compute the covariance XYcov between the point sets Xc, Yc
@@ -117,6 +326,16 @@ def corresponding_points_alignment(
117326
# decompose the covariance matrix XYcov
118327
U, S, V = torch.svd(XYcov)
119328

329+
# catch ambiguous rotation by checking the magnitude of singular values
330+
if (S.abs() <= AMBIGUOUS_ROT_SINGULAR_THR).any() and not (
331+
num_points < (dim + 1)
332+
).any():
333+
warnings.warn(
334+
"Excessively low rank of "
335+
+ "cross-correlation between aligned point clouds. "
336+
+ "corresponding_points_alignment cannot return a unique rotation."
337+
)
338+
120339
# identity matrix used for fixing reflections
121340
E = torch.eye(dim, dtype=XYcov.dtype, device=XYcov.device)[None].repeat(b, 1, 1)
122341

@@ -148,26 +367,18 @@ def corresponding_points_alignment(
148367
# unit scaling since we do not estimate scale
149368
s = T.new_ones(b)
150369

151-
return R, T, s
370+
return SimilarityTransform(R, T, s)
152371

153372

154-
def _convert_point_cloud_to_tensor(pcl: Union[torch.Tensor, Pointclouds]):
373+
def _apply_similarity_transform(
374+
X: torch.Tensor, R: torch.Tensor, T: torch.Tensor, s: torch.Tensor
375+
) -> torch.Tensor:
155376
"""
156-
If `type(pcl)==Pointclouds`, converts a `pcl` object to a
157-
padded representation and returns it together with the number of points
158-
per batch. Otherwise, returns the input itself with the number of points
159-
set to the size of the second dimension of `pcl`.
377+
Applies a similarity transformation parametrized with a batch of orthonormal
378+
matrices `R` of shape `(minibatch, d, d)`, a batch of translations `T`
379+
of shape `(minibatch, d)` and a batch of scaling factors `s`
380+
of shape `(minibatch,)` to a given `d`-dimensional cloud `X`
381+
of shape `(minibatch, num_points, d)`
160382
"""
161-
if isinstance(pcl, Pointclouds):
162-
X = pcl.points_padded()
163-
num_points = pcl.num_points_per_cloud()
164-
elif torch.is_tensor(pcl):
165-
X = pcl
166-
num_points = X.shape[1] * torch.ones(
167-
X.shape[0], device=X.device, dtype=torch.int64
168-
)
169-
else:
170-
raise ValueError(
171-
"The inputs X, Y should be either Pointclouds objects or tensors."
172-
)
173-
return X, num_points
383+
X = s[:, None, None] * torch.bmm(X, R) + T[:, None, :]
384+
return X

0 commit comments

Comments
 (0)