@@ -1,16 +1,18 @@
-#!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
 
 import warnings
-from typing import Tuple, Union
+from typing import List, Optional, Tuple, Union
 import torch
 
 from pytorch3d.structures.pointclouds import Pointclouds
+from pytorch3d.structures import utils as strutil
+from pytorch3d.ops import utils as oputil
 
 
 def corresponding_points_alignment(
     X: Union[torch.Tensor, Pointclouds],
     Y: Union[torch.Tensor, Pointclouds],
+    weights: Union[torch.Tensor, List[torch.Tensor], None] = None,
     estimate_scale: bool = False,
     allow_reflection: bool = False,
     eps: float = 1e-8,
@@ -28,9 +30,14 @@ def corresponding_points_alignment(
 
     Args:
         X: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
-            or a `Pointclouds` object.
+            or a `Pointclouds` object.
         Y: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
-            or a `Pointclouds` object.
+            or a `Pointclouds` object.
+        weights: Batch of non-negative weights of
+            shape `(minibatch, num_point)` or a list of `minibatch` 1-dimensional
+            tensors that may have different shapes; in that case, the length of
+            the i-th tensor should be equal to the number of points in X_i and Y_i.
+            Passing `None` means uniform weights.
         estimate_scale: If `True`, also estimates a scaling component `s`
             of the transformation. Otherwise assumes an identity
             scale and returns a tensor of ones.
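A quick illustration of the two `weights` forms described in the new docstring (a padded tensor, or a list of per-cloud 1-D tensors). This is only a sketch, not part of the commit; it assumes `corresponding_points_alignment` is importable from `pytorch3d.ops`, and the point data is synthetic:

import torch
from pytorch3d.ops import corresponding_points_alignment

# Two synthetic clouds of 100 3D points each; Y is X shifted along the x axis.
X = torch.rand(2, 100, 3)
Y = X + torch.tensor([0.1, 0.0, 0.0])

# Form 1: a padded (minibatch, num_point) weight tensor.
w_padded = torch.ones(2, 100)
# Form 2: one 1-D weight tensor per cloud (lengths must match the point counts).
w_list = [torch.ones(100), torch.ones(100)]

R, T, s = corresponding_points_alignment(X, Y, weights=w_padded)
R, T, s = corresponding_points_alignment(X, Y, weights=w_list)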
@@ -59,25 +66,45 @@ def corresponding_points_alignment(
             "Point sets X and Y have to have the same \
             number of batches, points and dimensions."
         )
+    if weights is not None:
+        if isinstance(weights, list):
+            if any(np != w.shape[0] for np, w in zip(num_points, weights)):
+                raise ValueError(
+                    "number of weights should equal to the "
+                    + "number of points in the point cloud."
+                )
+            weights = [w[..., None] for w in weights]
+            weights = strutil.list_to_padded(weights)[..., 0]
+
+        if Xt.shape[:2] != weights.shape:
+            raise ValueError(
+                "weights should have the same first two dimensions as X."
+            )
 
     b, n, dim = Xt.shape
 
-    # compute the centroids of the point sets
-    Xmu = Xt.sum(1) / torch.clamp(num_points[:, None], 1)
-    Ymu = Yt.sum(1) / torch.clamp(num_points[:, None], 1)
-
-    # mean-center the point sets
-    Xc = Xt - Xmu[:, None]
-    Yc = Yt - Ymu[:, None]
-
     if (num_points < Xt.shape[1]).any() or (num_points < Yt.shape[1]).any():
         # in case we got Pointclouds as input, mask the unused entries in Xc, Yc
         mask = (
-            torch.arange(n, dtype=torch.int64, device=Xc.device)[None]
+            torch.arange(n, dtype=torch.int64, device=Xt.device)[None]
             < num_points[:, None]
-        ).type_as(Xc)
-        Xc *= mask[:, :, None]
-        Yc *= mask[:, :, None]
+        ).type_as(Xt)
+        weights = mask if weights is None else mask * weights.type_as(Xt)
+
+    # compute the centroids of the point sets
+    Xmu = oputil.wmean(Xt, weights, eps=eps)
+    Ymu = oputil.wmean(Yt, weights, eps=eps)
+
+    # mean-center the point sets
+    Xc = Xt - Xmu
+    Yc = Yt - Ymu
+
+    total_weight = torch.clamp(num_points, 1)
+    # special handling for heterogeneous point clouds and/or input weights
+    if weights is not None:
+        Xc *= weights[:, :, None]
+        Yc *= weights[:, :, None]
+        total_weight = torch.clamp(weights.sum(1), eps)
 
     if (num_points < (dim + 1)).any():
         warnings.warn(
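In the hunk above, the explicit sum-and-divide centroid computation is replaced by `oputil.wmean`, and `total_weight` takes over the normalization previously done with `num_points`. The following is a minimal, hypothetical sketch of the weighted mean that `wmean` is used for here, assuming it reduces over the point dimension with `keepdim=True` (which is why the later hunks index `Ymu[:, 0, :]`); `weighted_mean_sketch` is an illustrative stand-in, not the library function:

import torch

def weighted_mean_sketch(x, weights=None, eps=1e-8):
    # x: (minibatch, num_point, d); weights: (minibatch, num_point) or None.
    if weights is None:
        # uniform weights reduce to a plain mean over the points
        return x.mean(dim=-2, keepdim=True)
    w = weights[..., None]  # broadcast the weights over the d coordinates
    return (x * w).sum(dim=-2, keepdim=True) / w.sum(dim=-2, keepdim=True).clamp(eps)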
@@ -87,7 +114,7 @@ def corresponding_points_alignment(
 
     # compute the covariance XYcov between the point sets Xc, Yc
     XYcov = torch.bmm(Xc.transpose(2, 1), Yc)
-    XYcov = XYcov / torch.clamp(num_points[:, None, None], 1)
+    XYcov = XYcov / total_weight[:, None, None]
 
     # decompose the covariance matrix XYcov
     U, S, V = torch.svd(XYcov)
@@ -111,17 +138,16 @@ def corresponding_points_alignment(
     if estimate_scale:
        # estimate the scaling component of the transformation
        trace_ES = (torch.diagonal(E, dim1=1, dim2=2) * S).sum(1)
-        Xcov = (Xc * Xc).sum((1, 2)) / torch.clamp(num_points, 1)
+        Xcov = (Xc * Xc).sum((1, 2)) / total_weight
 
        # the scaling component
        s = trace_ES / torch.clamp(Xcov, eps)
 
        # translation component
-        T = Ymu - s[:, None] * torch.bmm(Xmu[:, None], R)[:, 0, :]
-
+        T = Ymu[:, 0, :] - s[:, None] * torch.bmm(Xmu, R)[:, 0, :]
     else:
        # translation component
-        T = Ymu - torch.bmm(Xmu[:, None], R)[:, 0]
+        T = Ymu[:, 0, :] - torch.bmm(Xmu, R)[:, 0, :]
 
        # unit scaling since we do not estimate scale
        s = T.new_ones(b)
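To see the overall effect of the change, here is a hedged end-to-end sketch (not from the commit) that applies a known rigid motion to a synthetic cloud and recovers it with non-uniform per-point weights. It assumes the `s * X @ R + T` convention used by this function and that it is exported from `pytorch3d.ops`:

import math
import torch
from pytorch3d.ops import corresponding_points_alignment

# A known rotation about the z axis and a translation, applied to a synthetic cloud.
theta = 0.3
R_gt = torch.tensor([[[math.cos(theta), -math.sin(theta), 0.0],
                      [math.sin(theta),  math.cos(theta), 0.0],
                      [0.0,              0.0,             1.0]]])
T_gt = torch.tensor([[0.5, -0.2, 0.1]])
X = torch.rand(1, 200, 3)
Y = torch.bmm(X, R_gt) + T_gt[:, None, :]

# Arbitrary non-uniform weights: the second half of the points counts less.
weights = torch.cat([torch.ones(1, 100), 0.1 * torch.ones(1, 100)], dim=1)

R, T, s = corresponding_points_alignment(X, Y, weights=weights)
X_aligned = s[:, None, None] * torch.bmm(X, R) + T[:, None, :]
print(torch.allclose(X_aligned, Y, atol=1e-4))  # should hold for exact, noise-free correspondences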