1
1
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2
2
3
3
import warnings
4
- from typing import List , Tuple , Union
4
+ from typing import TYPE_CHECKING , List , NamedTuple , Optional , Union
5
5
6
6
import torch
7
- from pytorch3d .ops import utils as oputil
7
+ from pytorch3d .ops import knn_points
8
8
from pytorch3d .structures import utils as strutil
9
- from pytorch3d .structures .pointclouds import Pointclouds
9
+
10
+ from . import utils as oputil
11
+
12
+
13
+ if TYPE_CHECKING :
14
+ from pytorch3d .structures .pointclouds import Pointclouds
15
+
16
+
17
+ # named tuples for inputs/outputs
18
class SimilarityTransform(NamedTuple):
    """
    Batch of similarity transformations, each consisting of a rotation,
    a translation and a scale.

    The shapes below follow the validation performed on `init_transform`
    in `iterative_closest_point`.
    """

    # batch of rotation matrices of shape `(minibatch, d, d)`
    R: torch.Tensor
    # batch of translations of shape `(minibatch, d)`
    T: torch.Tensor
    # batch of scaling factors of shape `(minibatch,)`
    s: torch.Tensor
22
+
23
+
24
class ICPSolution(NamedTuple):
    """
    Result returned by `iterative_closest_point`.
    """

    # `True` if the relative RMSE dropped below the threshold before
    # the maximum number of iterations was reached
    converged: bool
    # root mean squared error attained at termination;
    # `None` if no ICP iteration was executed
    rmse: Optional[torch.Tensor]
    # the input point cloud `X` transformed with the final transformation
    Xt: torch.Tensor
    # the final similarity transformation
    RTs: SimilarityTransform
    # the transformation estimated after each ICP iteration
    t_history: List[SimilarityTransform]
30
+
31
+
32
def iterative_closest_point(
    X: Union[torch.Tensor, "Pointclouds"],
    Y: Union[torch.Tensor, "Pointclouds"],
    init_transform: Optional[SimilarityTransform] = None,
    max_iterations: int = 100,
    relative_rmse_thr: float = 1e-6,
    estimate_scale: bool = False,
    allow_reflection: bool = False,
    verbose: bool = False,
) -> ICPSolution:
    """
    Executes the iterative closest point (ICP) algorithm [1, 2] in order to find
    a similarity transformation (rotation `R`, translation `T`, and
    optionally scale `s`) between two given differently-sized sets of
    `d`-dimensional points `X` and `Y`, such that:

    `s[i] X[i] R[i] + T[i] = Y[NN[i]]`,

    for all batch indices `i` in the least squares sense. Here, `Y[NN[i]]`
    stands for the points of `Y[i]` that are the nearest neighbors of
    the (transformed) points in `X[i]`.
    Note, however, that the solution is only a local optimum.

    Args:
        **X**: Batch of `d`-dimensional points
            of shape `(minibatch, num_points_X, d)` or a `Pointclouds` object.
        **Y**: Batch of `d`-dimensional points
            of shape `(minibatch, num_points_Y, d)` or a `Pointclouds` object.
        **init_transform**: A named-tuple `SimilarityTransform` of tensors
            `R`, `T`, `s`, where `R` is a batch of orthonormal matrices of
            shape `(minibatch, d, d)`, `T` is a batch of translations
            of shape `(minibatch, d)` and `s` is a batch of scaling factors
            of shape `(minibatch,)`.
        **max_iterations**: The maximum number of ICP iterations.
        **relative_rmse_thr**: A threshold on the relative root mean squared error
            used to terminate the algorithm.
        **estimate_scale**: If `True`, also estimates a scaling component `s`
            of the transformation. Otherwise assumes the identity
            scale and returns a tensor of ones.
        **allow_reflection**: If `True`, allows the algorithm to return `R`
            which is orthonormal but has determinant==-1.
        **verbose**: If `True`, prints status messages during each ICP iteration.

    Returns:
        A named tuple `ICPSolution` with the following fields:
        **converged**: A boolean flag denoting whether the algorithm converged
            successfully (=`True`) or not (=`False`).
        **rmse**: Attained root mean squared error after termination of ICP
            (`None` if no iteration was executed).
        **Xt**: The point cloud `X` transformed with the final transformation
            (`R`, `T`, `s`). If `X` is a `Pointclouds` object, returns an
            instance of `Pointclouds`, otherwise returns `torch.Tensor`.
        **RTs**: A named tuple `SimilarityTransform` containing
            a batch of similarity transforms with fields:
            **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
            **T**: Batch of translations of shape `(minibatch, d)`.
            **s**: batch of scaling factors of shape `(minibatch, )`.
        **t_history**: A list of named tuples `SimilarityTransform` holding
            the transformation parameters after each ICP iteration.

    Raises:
        ValueError: If `X` and `Y` differ in batch size or point dimension,
            or if `init_transform` is not an `(R, T, s)` triple of tensors
            with the shapes described above.

    References:
        [1] Besl & McKay: A Method for Registration of 3-D Shapes. TPAMI, 1992.
        [2] https://en.wikipedia.org/wiki/Iterative_closest_point
    """

    # make sure we convert input Pointclouds structures to
    # padded tensors of shape (N, P, d)
    Xt, num_points_X = oputil.convert_pointclouds_to_tensor(X)
    Yt, num_points_Y = oputil.convert_pointclouds_to_tensor(Y)

    b, size_X, dim = Xt.shape

    if (Xt.shape[2] != Yt.shape[2]) or (Xt.shape[0] != Yt.shape[0]):
        raise ValueError(
            "Point sets X and Y have to have the same "
            + "number of batches and data dimensions."
        )

    # Mask selecting the valid (non-padded) points of each cloud in X.
    # It is always derived from the per-cloud point counts: for dense inputs
    # it evaluates to all ones, while for heterogeneous inputs it keeps the
    # padded entries out of the alignment and the RMSE regardless of how
    # many points the clouds in Y contain.
    mask_X = (
        torch.arange(size_X, dtype=torch.int64, device=Xt.device)[None]
        < num_points_X[:, None]
    ).type_as(Xt)

    # clone the initial point cloud
    Xt_init = Xt.clone()

    if init_transform is not None:
        # parse the initial transform from the input and apply to Xt
        init_transform_msg = (
            "The initial transformation init_transform has to be "
            "a named tuple SimilarityTransform with elements (R, T, s). "
            "R are dim x dim orthonormal matrices of shape "
            "(minibatch, dim, dim), T is a batch of dim-dimensional "
            "translations of shape (minibatch, dim) and s is a batch "
            "of scalars of shape (minibatch,)."
        )
        try:
            R, T, s = init_transform
            shapes_ok = (
                R.shape == torch.Size((b, dim, dim))
                and T.shape == torch.Size((b, dim))
                and s.shape == torch.Size((b,))
            )
        except Exception as err:
            # deliberately broad: any failure to unpack or inspect the
            # user-supplied object is reported as an invalid argument
            raise ValueError(init_transform_msg) from err
        # explicit check rather than `assert`, which is stripped under -O
        if not shapes_ok:
            raise ValueError(init_transform_msg)
        # apply the init transform to the input point cloud
        Xt = _apply_similarity_transform(Xt, R, T, s)
    else:
        # initialize the transformation with identity
        R = oputil.eyes(dim, b, device=Xt.device, dtype=Xt.dtype)
        T = Xt.new_zeros((b, dim))
        s = Xt.new_ones(b)

    prev_rmse = None
    rmse = None
    # stays at -1 if max_iterations <= 0 and the loop body never runs
    iteration = -1
    converged = False

    # initialize the transformation history
    t_history = []

    # the main loop over ICP iterations
    for iteration in range(max_iterations):
        # nearest neighbor in Yt of every point of the current Xt
        Xt_nn_points = knn_points(
            Xt, Yt, lengths1=num_points_X, lengths2=num_points_Y, K=1, return_nn=True
        )[2][:, :, 0, :]

        # get the alignment of the nearest neighbors from Yt with Xt_init
        R, T, s = corresponding_points_alignment(
            Xt_init,
            Xt_nn_points,
            weights=mask_X,
            estimate_scale=estimate_scale,
            allow_reflection=allow_reflection,
        )

        # apply the estimated similarity transform to Xt_init
        Xt = _apply_similarity_transform(Xt_init, R, T, s)

        # add the current transformation to the history
        t_history.append(SimilarityTransform(R, T, s))

        # compute the root mean squared error
        Xt_sq_diff = ((Xt - Xt_nn_points) ** 2).sum(2)
        rmse = oputil.wmean(Xt_sq_diff[:, :, None], mask_X).sqrt()[:, 0, 0]

        # compute the relative rmse
        if prev_rmse is None:
            relative_rmse = rmse.new_ones(b)
        else:
            # A previous error of exactly zero means the element is already
            # perfectly aligned; report zero relative improvement instead of
            # the NaN that 0 / 0 would produce, which could never pass the
            # convergence test below.
            relative_rmse = torch.where(
                prev_rmse > 0,
                (prev_rmse - rmse) / prev_rmse,
                torch.zeros_like(rmse),
            )

        if verbose:
            rmse_msg = (
                f"ICP iteration {iteration}: mean/max rmse = "
                + f"{rmse.mean():1.2e}/{rmse.max():1.2e} "
                + f"; mean relative rmse = {relative_rmse.mean():1.2e}"
            )
            print(rmse_msg)

        # check for convergence
        if (relative_rmse <= relative_rmse_thr).all():
            converged = True
            break

        # update the previous rmse
        prev_rmse = rmse

    if verbose:
        if converged:
            print(f"ICP has converged in {iteration + 1} iterations.")
        else:
            print(f"ICP has not converged in {max_iterations} iterations.")

    # return the same container type that was passed in
    if oputil.is_pointclouds(X):
        Xt = X.update_padded(Xt)  # type: ignore

    return ICPSolution(converged, rmse, Xt, SimilarityTransform(R, T, s), t_history)
214
+
215
+
216
# threshold for checking that point cross-correlation
# is full rank in corresponding_points_alignment
AMBIGUOUS_ROT_SINGULAR_THR = 1e-15
10
219
11
220
12
221
def corresponding_points_alignment (
13
- X : Union [torch .Tensor , Pointclouds ],
14
- Y : Union [torch .Tensor , Pointclouds ],
222
+ X : Union [torch .Tensor , " Pointclouds" ],
223
+ Y : Union [torch .Tensor , " Pointclouds" ],
15
224
weights : Union [torch .Tensor , List [torch .Tensor ], None ] = None ,
16
225
estimate_scale : bool = False ,
17
226
allow_reflection : bool = False ,
18
- eps : float = 1e-8 ,
19
- ) -> Tuple [ torch . Tensor , torch . Tensor , torch . Tensor ] :
227
+ eps : float = 1e-9 ,
228
+ ) -> SimilarityTransform :
20
229
"""
21
230
Finds a similarity transformation (rotation `R`, translation `T`
22
231
and optionally scale `s`) between two given sets of corresponding
@@ -29,25 +238,25 @@ def corresponding_points_alignment(
29
238
The algorithm is also known as Umeyama [1].
30
239
31
240
Args:
32
- X : Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
241
+ **X** : Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
33
242
or a `Pointclouds` object.
34
- Y : Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
243
+ **Y** : Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
35
244
or a `Pointclouds` object.
36
- weights: Batch of non-negative weights of
245
+ ** weights** : Batch of non-negative weights of
37
246
shape `(minibatch, num_point)` or list of `minibatch` 1-dimensional
38
247
tensors that may have different shapes; in that case, the length of
39
248
i-th tensor should be equal to the number of points in X_i and Y_i.
40
249
Passing `None` means uniform weights.
41
- estimate_scale: If `True`, also estimates a scaling component `s`
250
+ ** estimate_scale** : If `True`, also estimates a scaling component `s`
42
251
of the transformation. Otherwise assumes an identity
43
252
scale and returns a tensor of ones.
44
- allow_reflection: If `True`, allows the algorithm to return `R`
253
+ ** allow_reflection** : If `True`, allows the algorithm to return `R`
45
254
which is orthonormal but has determinant==-1.
46
- eps: A scalar for clamping to avoid dividing by zero. Active for the
255
+ ** eps** : A scalar for clamping to avoid dividing by zero. Active for the
47
256
code that estimates the output scale `s`.
48
257
49
258
Returns:
50
- 3-element tuple containing
259
+ 3-element named tuple `SimilarityTransform` containing
51
260
- **R**: Batch of orthonormal matrices of shape `(minibatch, d, d)`.
52
261
- **T**: Batch of translations of shape `(minibatch, d)`.
53
262
- **s**: batch of scaling factors of shape `(minibatch, )`.
@@ -58,8 +267,8 @@ def corresponding_points_alignment(
58
267
"""
59
268
60
269
# make sure we convert input Pointclouds structures to tensors
61
- Xt , num_points = _convert_point_cloud_to_tensor (X )
62
- Yt , num_points_Y = _convert_point_cloud_to_tensor (Y )
270
+ Xt , num_points = oputil . convert_pointclouds_to_tensor (X )
271
+ Yt , num_points_Y = oputil . convert_pointclouds_to_tensor (Y )
63
272
64
273
if (Xt .shape != Yt .shape ) or (num_points != num_points_Y ).any ():
65
274
raise ValueError (
@@ -90,8 +299,8 @@ def corresponding_points_alignment(
90
299
weights = mask if weights is None else mask * weights .type_as (Xt )
91
300
92
301
# compute the centroids of the point sets
93
- Xmu = oputil .wmean (Xt , weights , eps = eps )
94
- Ymu = oputil .wmean (Yt , weights , eps = eps )
302
+ Xmu = oputil .wmean (Xt , weight = weights , eps = eps )
303
+ Ymu = oputil .wmean (Yt , weight = weights , eps = eps )
95
304
96
305
# mean-center the point sets
97
306
Xc = Xt - Xmu
@@ -107,7 +316,7 @@ def corresponding_points_alignment(
107
316
if (num_points < (dim + 1 )).any ():
108
317
warnings .warn (
109
318
"The size of one of the point clouds is <= dim+1. "
110
- + "corresponding_points_alignment can't return a unique solution ."
319
+ + "corresponding_points_alignment cannot return a unique rotation ."
111
320
)
112
321
113
322
# compute the covariance XYcov between the point sets Xc, Yc
@@ -117,6 +326,16 @@ def corresponding_points_alignment(
117
326
# decompose the covariance matrix XYcov
118
327
U , S , V = torch .svd (XYcov )
119
328
329
+ # catch ambiguous rotation by checking the magnitude of singular values
330
+ if (S .abs () <= AMBIGUOUS_ROT_SINGULAR_THR ).any () and not (
331
+ num_points < (dim + 1 )
332
+ ).any ():
333
+ warnings .warn (
334
+ "Excessively low rank of "
335
+ + "cross-correlation between aligned point clouds. "
336
+ + "corresponding_points_alignment cannot return a unique rotation."
337
+ )
338
+
120
339
# identity matrix used for fixing reflections
121
340
E = torch .eye (dim , dtype = XYcov .dtype , device = XYcov .device )[None ].repeat (b , 1 , 1 )
122
341
@@ -148,26 +367,18 @@ def corresponding_points_alignment(
148
367
# unit scaling since we do not estimate scale
149
368
s = T .new_ones (b )
150
369
151
- return R , T , s
370
+ return SimilarityTransform ( R , T , s )
152
371
153
372
154
- def _convert_point_cloud_to_tensor (pcl : Union [torch .Tensor , Pointclouds ]):
373
+ def _apply_similarity_transform (
374
+ X : torch .Tensor , R : torch .Tensor , T : torch .Tensor , s : torch .Tensor
375
+ ) -> torch .Tensor :
155
376
"""
156
- If `type(pcl)==Pointclouds`, converts a `pcl` object to a
157
- padded representation and returns it together with the number of points
158
- per batch. Otherwise, returns the input itself with the number of points
159
- set to the size of the second dimension of `pcl`.
377
+ Applies a similarity transformation parametrized with a batch of orthonormal
378
+ matrices `R` of shape `(minibatch, d, d)`, a batch of translations `T`
379
+ of shape `(minibatch, d)` and a batch of scaling factors `s`
380
+ of shape `(minibatch,)` to a given `d`-dimensional cloud `X`
381
+ of shape `(minibatch, num_points, d)`
160
382
"""
161
- if isinstance (pcl , Pointclouds ):
162
- X = pcl .points_padded ()
163
- num_points = pcl .num_points_per_cloud ()
164
- elif torch .is_tensor (pcl ):
165
- X = pcl
166
- num_points = X .shape [1 ] * torch .ones (
167
- X .shape [0 ], device = X .device , dtype = torch .int64
168
- )
169
- else :
170
- raise ValueError (
171
- "The inputs X, Y should be either Pointclouds objects or tensors."
172
- )
173
- return X , num_points
383
+ X = s [:, None , None ] * torch .bmm (X , R ) + T [:, None , :]
384
+ return X
0 commit comments