
Commit e37085d

shapovalov authored and facebook-github-bot committed
Weighted Umeyama.
Summary:
1. Introduced weights to the Umeyama implementation. This will be needed for weighted ePnP but is useful on its own.
2. Refactored to use the same code for the Pointclouds mask and passed weights.
3. Added test cases with random weights.
4. Fixed a bug in tests that call the function with 0 points (fails randomly in PyTorch 1.3; will be fixed in the next release: pytorch/pytorch#31421).

Reviewed By: gkioxari

Differential Revision: D20070293

fbshipit-source-id: e9f549507ef6dcaa0688a0f17342e6d7a9a4336c
1 parent e5b1d6d commit e37085d

File tree: 6 files changed, +278 -50 lines changed

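The commit ships no usage example; the following sketch illustrates the new `weights` argument of `corresponding_points_alignment`, based on the signature and docstring added in pytorch3d/ops/points_alignment.py below. The random data, the hand-built ground-truth transform, and the `R, T, s` unpacking of the result are illustrative assumptions, not code from the repository.

    import math

    import torch

    from pytorch3d.ops.points_alignment import corresponding_points_alignment

    # batch of 2 clouds with 100 corresponding 3D points each
    X = torch.randn(2, 100, 3)

    # a hand-built rigid motion to recover: rotation about z by 30 degrees plus a translation
    a = math.radians(30.0)
    R_gt = torch.tensor(
        [
            [math.cos(a), math.sin(a), 0.0],
            [-math.sin(a), math.cos(a), 0.0],
            [0.0, 0.0, 1.0],
        ]
    ).repeat(2, 1, 1)
    T_gt = torch.tensor([0.5, -0.3, 1.0]).repeat(2, 1)
    Y = torch.bmm(X, R_gt) + T_gt[:, None]

    # per-point confidences; zero-weight points contribute nothing to the estimate
    weights = torch.rand(2, 100)

    R, T, s = corresponding_points_alignment(X, Y, weights=weights, estimate_scale=False)
    # with exact correspondences any positive weighting recovers R_gt and T_gt;
    # the weights start to matter once the correspondences are noisy or padded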

pytorch3d/ops/points_alignment.py (+47 -21)
@@ -1,16 +1,18 @@
-#!/usr/bin/env python3
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

 import warnings
-from typing import Tuple, Union
+from typing import List, Optional, Tuple, Union

 import torch

 from pytorch3d.structures.pointclouds import Pointclouds
+from pytorch3d.structures import utils as strutil
+from pytorch3d.ops import utils as oputil


 def corresponding_points_alignment(
     X: Union[torch.Tensor, Pointclouds],
     Y: Union[torch.Tensor, Pointclouds],
+    weights: Union[torch.Tensor, List[torch.Tensor], None] = None,
     estimate_scale: bool = False,
     allow_reflection: bool = False,
     eps: float = 1e-8,
@@ -28,9 +30,14 @@ def corresponding_points_alignment(

     Args:
         X: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
-            or a `Pointclouds` object.
+            or a `Pointclouds` object.
         Y: Batch of `d`-dimensional points of shape `(minibatch, num_point, d)`
-            or a `Pointclouds` object.
+            or a `Pointclouds` object.
+        weights: Batch of non-negative weights of
+            shape `(minibatch, num_point)` or list of `minibatch` 1-dimensional
+            tensors that may have different shapes; in that case, the length of
+            i-th tensor should be equal to the number of points in X_i and Y_i.
+            Passing `None` means uniform weights.
         estimate_scale: If `True`, also estimates a scaling component `s`
             of the transformation. Otherwise assumes an identity
             scale and returns a tensor of ones.
@@ -59,25 +66,45 @@ def corresponding_points_alignment(
            "Point sets X and Y have to have the same \
            number of batches, points and dimensions."
         )
+    if weights is not None:
+        if isinstance(weights, list):
+            if any(np != w.shape[0] for np, w in zip(num_points, weights)):
+                raise ValueError(
+                    "number of weights should equal to the "
+                    + "number of points in the point cloud."
+                )
+            weights = [w[..., None] for w in weights]
+            weights = strutil.list_to_padded(weights)[..., 0]
+
+        if Xt.shape[:2] != weights.shape:
+            raise ValueError(
+                "weights should have the same first two dimensions as X."
+            )

     b, n, dim = Xt.shape

-    # compute the centroids of the point sets
-    Xmu = Xt.sum(1) / torch.clamp(num_points[:, None], 1)
-    Ymu = Yt.sum(1) / torch.clamp(num_points[:, None], 1)
-
-    # mean-center the point sets
-    Xc = Xt - Xmu[:, None]
-    Yc = Yt - Ymu[:, None]
-
     if (num_points < Xt.shape[1]).any() or (num_points < Yt.shape[1]).any():
         # in case we got Pointclouds as input, mask the unused entries in Xc, Yc
         mask = (
-            torch.arange(n, dtype=torch.int64, device=Xc.device)[None]
+            torch.arange(n, dtype=torch.int64, device=Xt.device)[None]
             < num_points[:, None]
-        ).type_as(Xc)
-        Xc *= mask[:, :, None]
-        Yc *= mask[:, :, None]
+        ).type_as(Xt)
+        weights = mask if weights is None else mask * weights.type_as(Xt)
+
+    # compute the centroids of the point sets
+    Xmu = oputil.wmean(Xt, weights, eps=eps)
+    Ymu = oputil.wmean(Yt, weights, eps=eps)
+
+    # mean-center the point sets
+    Xc = Xt - Xmu
+    Yc = Yt - Ymu
+
+    total_weight = torch.clamp(num_points, 1)
+    # special handling for heterogeneous point clouds and/or input weights
+    if weights is not None:
+        Xc *= weights[:, :, None]
+        Yc *= weights[:, :, None]
+        total_weight = torch.clamp(weights.sum(1), eps)

     if (num_points < (dim + 1)).any():
         warnings.warn(
@@ -87,7 +114,7 @@ def corresponding_points_alignment(

     # compute the covariance XYcov between the point sets Xc, Yc
     XYcov = torch.bmm(Xc.transpose(2, 1), Yc)
-    XYcov = XYcov / torch.clamp(num_points[:, None, None], 1)
+    XYcov = XYcov / total_weight[:, None, None]

     # decompose the covariance matrix XYcov
     U, S, V = torch.svd(XYcov)
@@ -111,17 +138,16 @@ def corresponding_points_alignment(
     if estimate_scale:
         # estimate the scaling component of the transformation
         trace_ES = (torch.diagonal(E, dim1=1, dim2=2) * S).sum(1)
-        Xcov = (Xc * Xc).sum((1, 2)) / torch.clamp(num_points, 1)
+        Xcov = (Xc * Xc).sum((1, 2)) / total_weight

         # the scaling component
         s = trace_ES / torch.clamp(Xcov, eps)

         # translation component
-        T = Ymu - s[:, None] * torch.bmm(Xmu[:, None], R)[:, 0, :]
-
+        T = Ymu[:, 0, :] - s[:, None] * torch.bmm(Xmu, R)[:, 0, :]
     else:
         # translation component
-        T = Ymu - torch.bmm(Xmu[:, None], R)[:, 0]
+        T = Ymu[:, 0, :] - torch.bmm(Xmu, R)[:, 0, :]

         # unit scaling since we do not estimate scale
         s = T.new_ones(b)
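
As documented above, `weights` can also be a list of per-cloud 1-dimensional tensors, which pairs with heterogeneous `Pointclouds` batches. A minimal sketch of that calling pattern follows; the cloud sizes, the identical-copy alignment, and the expected identity result are illustrative assumptions rather than code from the commit.

    import torch

    from pytorch3d.ops.points_alignment import corresponding_points_alignment
    from pytorch3d.structures import Pointclouds

    # two clouds with different numbers of points
    points = [torch.randn(120, 3), torch.randn(80, 3)]
    X = Pointclouds(points)
    Y = Pointclouds([p.clone() for p in points])  # identical copies of the clouds

    # one 1-D weight tensor per cloud, lengths matching the cloud sizes (120 and 80)
    weights = [torch.rand(120), torch.rand(80)]

    R, T, s = corresponding_points_alignment(X, Y, weights=weights)
    # aligning a cloud with an identical copy should give R close to identity,
    # T close to zero and s equal to ones, independently of the weights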

pytorch3d/ops/utils.py (+44)
@@ -0,0 +1,44 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+from typing import Optional, Tuple, Union
+
+import torch
+
+
+def wmean(
+    x: torch.Tensor,
+    weight: Optional[torch.Tensor] = None,
+    dim: Union[int, Tuple[int]] = -2,
+    keepdim: bool = True,
+    eps: float = 1e-9,
+) -> torch.Tensor:
+    """
+    Finds the mean of the input tensor across the specified dimension.
+    If the `weight` argument is provided, computes weighted mean.
+    Args:
+        x: tensor of shape `(*, D)`, where D is assumed to be spatial;
+        weights: if given, non-negative tensor of shape `(*,)`. It must be
+            broadcastable to `x.shape[:-1]`. Note that the weights for
+            the last (spatial) dimension are assumed same;
+        dim: dimension(s) in `x` to average over;
+        keepdim: tells whether to keep the resulting singleton dimension.
+        eps: minimum clamping value in the denominator.
+    Returns:
+        the mean tensor:
+        * if `weights` is None => `mean(x, dim)`,
+        * otherwise => `sum(x*w, dim) / max{sum(w, dim), eps}`.
+    """
+    args = dict(dim=dim, keepdim=keepdim)
+
+    if weight is None:
+        return x.mean(**args)
+
+    if any(
+        xd != wd and xd != 1 and wd != 1
+        for xd, wd in zip(x.shape[-2::-1], weight.shape[::-1])
+    ):
+        raise ValueError("wmean: weights are not compatible with the tensor")
+
+    return (
+        (x * weight[..., None]).sum(**args)
+        / weight[..., None].sum(**args).clamp(eps)
+    )
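
For orientation, a small illustrative example of the `wmean` semantics documented above, namely `sum(x*w, dim) / max(sum(w, dim), eps)` over the second-to-last dimension by default; the concrete numbers and the expected outputs in the comments are mine, not part of the commit.

    import torch

    from pytorch3d.ops import utils as oputil

    # one batch of three 2-D points
    x = torch.tensor([[[0.0, 0.0], [2.0, 2.0], [10.0, 10.0]]])  # shape (1, 3, 2)
    w = torch.tensor([[1.0, 1.0, 0.0]])                         # zero weight drops the last point

    oputil.wmean(x)                    # unweighted mean over dim=-2: tensor([[[4., 4.]]])
    oputil.wmean(x, w)                 # weighted mean: (1*[0,0] + 1*[2,2]) / 2 = tensor([[[1., 1.]]])
    oputil.wmean(x, w, keepdim=False)  # same values with the point dimension squeezed: tensor([[1., 1.]])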

tests/bm_points_alignment.py (+1)
@@ -16,6 +16,7 @@ def bm_corresponding_points_alignment() -> None:
         "dim": [3, 20],
         "estimate_scale": [True, False],
         "n_points": [100, 10000],
+        "random_weights": [False, True],
         "use_pointclouds": [False],
     }

tests/common_testing.py (+8 -3)
@@ -1,5 +1,6 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.

+from typing import Optional

 import unittest

@@ -35,13 +36,15 @@ def assertClose(
         *,
         rtol: float = 1e-05,
         atol: float = 1e-08,
-        equal_nan: bool = False
+        equal_nan: bool = False,
+        msg: Optional[str] = None,
     ) -> None:
         """
         Verify that two tensors or arrays are the same shape and close.
         Args:
             input, other: two tensors or two arrays.
             rtol, atol, equal_nan: as for torch.allclose.
+            msg: message in case the assertion is violated.
         Note:
             Optional arguments here are all keyword-only, to avoid confusion
             with msg arguments on other assert functions.
@@ -54,5 +57,7 @@ def assertClose(
                 input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
             )
         else:
-            close = np.allclose(input, other, rtol=rtol, atol=atol, equal_nan=equal_nan)
-        self.assertTrue(close)
+            close = np.allclose(
+                input, other, rtol=rtol, atol=atol, equal_nan=equal_nan
+            )
+        self.assertTrue(close, msg)

tests/test_ops_utils.py (+75)
@@ -0,0 +1,75 @@
+# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
+import unittest
+
+import numpy as np
+import torch
+
+from common_testing import TestCaseMixin
+
+from pytorch3d.ops import utils as oputil
+
+class TestOpsUtils(TestCaseMixin, unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+        torch.manual_seed(42)
+        np.random.seed(42)
+
+    def test_wmean(self):
+        device = torch.device("cuda:0")
+        n_points = 20
+
+        x = torch.rand(n_points, 3, device=device)
+        weight = torch.rand(n_points, device=device)
+        x_np = x.cpu().data.numpy()
+        weight_np = weight.cpu().data.numpy()
+
+        # test unweighted
+        mean = oputil.wmean(x, keepdim=False)
+        mean_gt = np.average(x_np, axis=-2)
+        self.assertClose(mean.cpu().data.numpy(), mean_gt)
+
+        # test weighted
+        mean = oputil.wmean(x, weight=weight, keepdim=False)
+        mean_gt = np.average(x_np, axis=-2, weights=weight_np)
+        self.assertClose(mean.cpu().data.numpy(), mean_gt)
+
+        # test keepdim
+        mean = oputil.wmean(x, weight=weight, keepdim=True)
+        self.assertClose(mean[0].cpu().data.numpy(), mean_gt)
+
+        # test binary weights
+        mean = oputil.wmean(x, weight=weight > 0.5, keepdim=False)
+        mean_gt = np.average(x_np, axis=-2, weights=weight_np > 0.5)
+        self.assertClose(mean.cpu().data.numpy(), mean_gt)
+
+        # test broadcasting
+        x = torch.rand(10, n_points, 3, device=device)
+        x_np = x.cpu().data.numpy()
+        mean = oputil.wmean(x, weight=weight, keepdim=False)
+        mean_gt = np.average(x_np, axis=-2, weights=weight_np)
+        self.assertClose(mean.cpu().data.numpy(), mean_gt)
+
+        weight = weight[None, None, :].repeat(3, 1, 1)
+        mean = oputil.wmean(x, weight=weight, keepdim=False)
+        self.assertClose(mean[0].cpu().data.numpy(), mean_gt)
+
+        # test failing broadcasting
+        weight = torch.rand(x.shape[0], device=device)
+        with self.assertRaises(ValueError) as context:
+            oputil.wmean(x, weight=weight, keepdim=False)
+        self.assertTrue("weights are not compatible" in str(context.exception))
+
+        # test dim
+        weight = torch.rand(x.shape[0], n_points, device=device)
+        weight_np = np.tile(
+            weight[:, :, None].cpu().data.numpy(),
+            (1, 1, x_np.shape[-1]),
+        )
+        mean = oputil.wmean(x, dim=0, weight=weight, keepdim=False)
+        mean_gt = np.average(x_np, axis=0, weights=weight_np)
+        self.assertClose(mean.cpu().data.numpy(), mean_gt)
+
+        # test dim tuple
+        mean = oputil.wmean(x, dim=(0, 1), weight=weight, keepdim=False)
+        mean_gt = np.average(x_np, axis=(0, 1), weights=weight_np)
+        self.assertClose(mean.cpu().data.numpy(), mean_gt)
