Skip to content

Commit b2ac265

Browse files
davnov134facebook-github-bot
authored andcommitted
SE3 exponential and logarithm maps.
Summary: Implements the SE3 logarithm and exponential maps. (this is a second part of the split of D23326429) Outputs of `bm_se3`: ``` Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- SE3_EXP_1 738 885 678 SE3_EXP_10 717 877 698 SE3_EXP_100 718 847 697 SE3_EXP_1000 729 1181 686 -------------------------------------------------------------------------------- Benchmark Avg Time(μs) Peak Time(μs) Iterations -------------------------------------------------------------------------------- SE3_LOG_1 1451 2267 345 SE3_LOG_10 2185 2453 229 SE3_LOG_100 2217 2448 226 SE3_LOG_1000 2455 2599 204 -------------------------------------------------------------------------------- ``` Reviewed By: patricklabatut Differential Revision: D27852557 fbshipit-source-id: e42ccc9cfffe780e9cad24129de15624ae818472
1 parent 9f14e82 commit b2ac265

File tree

5 files changed

+561
-2
lines changed

5 files changed

+561
-2
lines changed

pytorch3d/transforms/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
rotation_6d_to_matrix,
2121
standardize_quaternion,
2222
)
23+
from .se3 import se3_exp_map, se3_log_map
2324
from .so3 import (
2425
so3_exponential_map,
2526
so3_exp_map,

pytorch3d/transforms/se3.py

+213
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
import torch
4+
5+
from .so3 import hat, _so3_exp_map, so3_log_map
6+
7+
8+
def se3_exp_map(log_transform: torch.Tensor, eps: float = 1e-4) -> torch.Tensor:
9+
"""
10+
Convert a batch of logarithmic representations of SE(3) matrices `log_transform`
11+
to a batch of 4x4 SE(3) matrices using the exponential map.
12+
See e.g. [1], Sec 9.4.2. for more detailed description.
13+
14+
A SE(3) matrix has the following form:
15+
```
16+
[ R 0 ]
17+
[ T 1 ] ,
18+
```
19+
where `R` is a 3x3 rotation matrix and `T` is a 3-D translation vector.
20+
SE(3) matrices are commonly used to represent rigid motions or camera extrinsics.
21+
22+
In the SE(3) logarithmic representation SE(3) matrices are
23+
represented as 6-dimensional vectors `[log_translation | log_rotation]`,
24+
i.e. a concatenation of two 3D vectors `log_translation` and `log_rotation`.
25+
26+
The conversion from the 6D representation to a 4x4 SE(3) matrix `transform`
27+
is done as follows:
28+
```
29+
transform = exp( [ hat(log_rotation) 0 ]
30+
[ log_translation 1 ] ) ,
31+
```
32+
where `exp` is the matrix exponential and `hat` is the Hat operator [2].
33+
34+
Note that for any `log_transform` with `0 <= ||log_rotation|| < 2pi`
35+
(i.e. the rotation angle is between 0 and 2pi), the following identity holds:
36+
```
37+
se3_log_map(se3_exponential_map(log_transform)) == log_transform
38+
```
39+
40+
The conversion has a singularity around `||log(transform)|| = 0`
41+
which is handled by clamping controlled with the `eps` argument.
42+
43+
Args:
44+
log_transform: Batch of vectors of shape `(minibatch, 6)`.
45+
eps: A threshold for clipping the squared norm of the rotation logarithm
46+
to avoid unstable gradients in the singular case.
47+
48+
Returns:
49+
Batch of transformation matrices of shape `(minibatch, 4, 4)`.
50+
51+
Raises:
52+
ValueError if `log_transform` is of incorrect shape.
53+
54+
[1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
55+
[2] https://en.wikipedia.org/wiki/Hat_operator
56+
"""
57+
58+
if log_transform.ndim != 2 or log_transform.shape[1] != 6:
59+
raise ValueError("Expected input to be of shape (N, 6).")
60+
61+
N, _ = log_transform.shape
62+
63+
log_translation = log_transform[..., :3]
64+
log_rotation = log_transform[..., 3:]
65+
66+
# rotation is an exponential map of log_rotation
67+
(
68+
R,
69+
rotation_angles,
70+
log_rotation_hat,
71+
log_rotation_hat_square,
72+
) = _so3_exp_map(log_rotation, eps=eps)
73+
74+
# translation is V @ T
75+
V = _se3_V_matrix(
76+
log_rotation,
77+
log_rotation_hat,
78+
log_rotation_hat_square,
79+
rotation_angles,
80+
eps=eps,
81+
)
82+
T = torch.bmm(V, log_translation[:, :, None])[:, :, 0]
83+
84+
transform = torch.zeros(
85+
N, 4, 4, dtype=log_transform.dtype, device=log_transform.device
86+
)
87+
88+
transform[:, :3, :3] = R
89+
transform[:, :3, 3] = T
90+
transform[:, 3, 3] = 1.0
91+
92+
return transform.permute(0, 2, 1)
93+
94+
95+
def se3_log_map(
96+
transform: torch.Tensor, eps: float = 1e-4, cos_bound: float = 1e-4
97+
) -> torch.Tensor:
98+
"""
99+
Convert a batch of 4x4 transformation matrices `transform`
100+
to a batch of 6-dimensional SE(3) logarithms of the SE(3) matrices.
101+
See e.g. [1], Sec 9.4.2. for more detailed description.
102+
103+
A SE(3) matrix has the following form:
104+
```
105+
[ R 0 ]
106+
[ T 1 ] ,
107+
```
108+
where `R` is an orthonormal 3x3 rotation matrix and `T` is a 3-D translation vector.
109+
SE(3) matrices are commonly used to represent rigid motions or camera extrinsics.
110+
111+
In the SE(3) logarithmic representation SE(3) matrices are
112+
represented as 6-dimensional vectors `[log_translation | log_rotation]`,
113+
i.e. a concatenation of two 3D vectors `log_translation` and `log_rotation`.
114+
115+
The conversion from the 4x4 SE(3) matrix `transform` to the
116+
6D representation `log_transform = [log_translation | log_rotation]`
117+
is done as follows:
118+
```
119+
log_transform = log(transform)
120+
log_translation = log_transform[3, :3]
121+
log_rotation = inv_hat(log_transform[:3, :3])
122+
```
123+
where `log` is the matrix logarithm
124+
and `inv_hat` is the inverse of the Hat operator [2].
125+
126+
Note that for any valid 4x4 `transform` matrix, the following identity holds:
127+
```
128+
se3_exp_map(se3_log_map(transform)) == transform
129+
```
130+
131+
The conversion has a singularity around `(transform=I)` which is handled
132+
by clamping controlled with the `eps` and `cos_bound` arguments.
133+
134+
Args:
135+
transform: batch of SE(3) matrices of shape `(minibatch, 4, 4)`.
136+
eps: A threshold for clipping the squared norm of the rotation logarithm
137+
to avoid division by zero in the singular case.
138+
cos_bound: Clamps the cosine of the rotation angle to
139+
[-1 + cos_bound, 3 - cos_bound] to avoid non-finite outputs.
140+
The non-finite outputs can be caused by passing small rotation angles
141+
to the `acos` function in `so3_rotation_angle` of `so3_log_map`.
142+
143+
Returns:
144+
Batch of logarithms of input SE(3) matrices
145+
of shape `(minibatch, 6)`.
146+
147+
Raises:
148+
ValueError if `transform` is of incorrect shape.
149+
ValueError if `R` has an unexpected trace.
150+
151+
[1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
152+
[2] https://en.wikipedia.org/wiki/Hat_operator
153+
"""
154+
155+
if transform.ndim != 3:
156+
raise ValueError("Input tensor shape has to be (N, 4, 4).")
157+
158+
N, dim1, dim2 = transform.shape
159+
if dim1 != 4 or dim2 != 4:
160+
raise ValueError("Input tensor shape has to be (N, 4, 4).")
161+
162+
if not torch.allclose(transform[:, :3, 3], torch.zeros_like(transform[:, :3, 3])):
163+
raise ValueError("All elements of `transform[:, :3, 3]` should be 0.")
164+
165+
# log_rot is just so3_log_map of the upper left 3x3 block
166+
R = transform[:, :3, :3].permute(0, 2, 1)
167+
log_rotation = so3_log_map(R, eps=eps, cos_bound=cos_bound)
168+
169+
# log_translation is V^-1 @ T
170+
T = transform[:, 3, :3]
171+
V = _se3_V_matrix(*_get_se3_V_input(log_rotation), eps=eps)
172+
log_translation = torch.linalg.solve(V, T[:, :, None])[:, :, 0]
173+
174+
return torch.cat((log_translation, log_rotation), dim=1)
175+
176+
177+
def _se3_V_matrix(
178+
log_rotation: torch.Tensor,
179+
log_rotation_hat: torch.Tensor,
180+
log_rotation_hat_square: torch.Tensor,
181+
rotation_angles: torch.Tensor,
182+
eps: float = 1e-4,
183+
) -> torch.Tensor:
184+
"""
185+
A helper function that computes the "V" matrix from [1], Sec 9.4.2.
186+
[1] https://jinyongjeong.github.io/Download/SE3/jlblanco2010geometry3d_techrep.pdf
187+
"""
188+
189+
V = (
190+
torch.eye(3, dtype=log_rotation.dtype, device=log_rotation.device)[None]
191+
+ log_rotation_hat
192+
* ((1 - torch.cos(rotation_angles)) / (rotation_angles ** 2))[:, None, None]
193+
+ (
194+
log_rotation_hat_square
195+
* ((rotation_angles - torch.sin(rotation_angles)) / (rotation_angles ** 3))[
196+
:, None, None
197+
]
198+
)
199+
)
200+
201+
return V
202+
203+
204+
def _get_se3_V_input(log_rotation: torch.Tensor, eps: float = 1e-4):
205+
"""
206+
A helper function that computes the input variables to the `_se3_V_matrix`
207+
function.
208+
"""
209+
nrms = (log_rotation ** 2).sum(-1)
210+
rotation_angles = torch.clamp(nrms, eps).sqrt()
211+
log_rotation_hat = hat(log_rotation)
212+
log_rotation_hat_square = torch.bmm(log_rotation_hat, log_rotation_hat)
213+
return log_rotation, log_rotation_hat, log_rotation_hat_square, rotation_angles

pytorch3d/transforms/so3.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def so3_relative_angle(
1414
R2: torch.Tensor,
1515
cos_angle: bool = False,
1616
cos_bound: float = 1e-4,
17+
eps: float = 1e-4,
1718
) -> torch.Tensor:
1819
"""
1920
Calculates the relative angle (in radians) between pairs of
@@ -33,7 +34,8 @@ def so3_relative_angle(
3334
of the `acos` call. Note that the non-finite outputs/gradients
3435
are returned when the angle is requested (i.e. `cos_angle==False`)
3536
and the rotation angle is close to 0 or π.
36-
37+
eps: Tolerance for the valid trace check of the relative rotation matrix
38+
in `so3_rotation_angle`.
3739
Returns:
3840
Corresponding rotation angles of shape `(minibatch,)`.
3941
If `cos_angle==True`, returns the cosine of the angles.
@@ -43,7 +45,7 @@ def so3_relative_angle(
4345
ValueError if `R1` or `R2` has an unexpected trace.
4446
"""
4547
R12 = torch.bmm(R1, R2.permute(0, 2, 1))
46-
return so3_rotation_angle(R12, cos_angle=cos_angle, cos_bound=cos_bound)
48+
return so3_rotation_angle(R12, cos_angle=cos_angle, cos_bound=cos_bound, eps=eps)
4749

4850

4951
def so3_rotation_angle(

tests/bm_se3.py

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
2+
3+
from fvcore.common.benchmark import benchmark
4+
from test_se3 import TestSE3
5+
6+
7+
def bm_se3() -> None:
8+
kwargs_list = [
9+
{"batch_size": 1},
10+
{"batch_size": 10},
11+
{"batch_size": 100},
12+
{"batch_size": 1000},
13+
]
14+
benchmark(TestSE3.se3_expmap, "SE3_EXP", kwargs_list, warmup_iters=1)
15+
benchmark(TestSE3.se3_logmap, "SE3_LOG", kwargs_list, warmup_iters=1)
16+
17+
18+
if __name__ == "__main__":
19+
bm_se3()

0 commit comments

Comments
 (0)