Skip to content

Commit 1b39ceb

Browse files
shapovalovfacebook-github-bot
authored andcommitted
Sign issue about quaternion_to_matrix and matrix_to_quaternion
Summary: As reported on github, `matrix_to_quaternion` was incorrect for rotations by 180˚. We resolved the sign of the component `i` based on the sign of `i*r`, assuming `r > 0`, which is untrue if `r == 0`. This diff handles special cases and ensures we use the non-zero elements to copy the sign from. Reviewed By: bottler Differential Revision: D29149465 fbshipit-source-id: cd508cc31567fc37ea3463dd7e8c8e8d5d64a235
1 parent a8610e9 commit 1b39ceb

File tree

2 files changed

+89
-17
lines changed

2 files changed

+89
-17
lines changed

pytorch3d/transforms/rotation_conversions.py

+40-13
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def _copysign(a, b):
8282
return torch.where(signs_differ, -a, a)
8383

8484

85-
def _sqrt_positive_part(x):
85+
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
8686
"""
8787
Returns torch.sqrt(torch.max(0, x))
8888
but with a zero subgradient where x is 0.
@@ -93,7 +93,7 @@ def _sqrt_positive_part(x):
9393
return ret
9494

9595

96-
def matrix_to_quaternion(matrix):
96+
def matrix_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
9797
"""
9898
Convert rotations given as rotation matrices to quaternions.
9999
@@ -105,17 +105,44 @@ def matrix_to_quaternion(matrix):
105105
"""
106106
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
107107
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
108-
m00 = matrix[..., 0, 0]
109-
m11 = matrix[..., 1, 1]
110-
m22 = matrix[..., 2, 2]
111-
o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
112-
x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
113-
y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
114-
z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
115-
o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
116-
o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
117-
o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
118-
return torch.stack((o0, o1, o2, o3), -1)
108+
109+
batch_dim = matrix.shape[:-2]
110+
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
111+
matrix.reshape(*batch_dim, 9), dim=-1
112+
)
113+
114+
q_abs = _sqrt_positive_part(
115+
torch.stack(
116+
[
117+
1.0 + m00 + m11 + m22,
118+
1.0 + m00 - m11 - m22,
119+
1.0 - m00 + m11 - m22,
120+
1.0 - m00 - m11 + m22,
121+
],
122+
dim=-1,
123+
)
124+
)
125+
126+
# we produce the desired quaternion multiplied by each of r, i, j, k
127+
quat_by_rijk = torch.stack(
128+
[
129+
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
130+
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
131+
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
132+
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
133+
],
134+
dim=-2,
135+
)
136+
137+
# clipping is not important here; if q_abs is small, the candidate won't be picked
138+
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].clip(0.1))
139+
140+
# if not for numerical problems, quat_candidates[i] should be same (up to a sign),
141+
# forall i; we pick the best-conditioned one (with the largest denominator)
142+
143+
return quat_candidates[
144+
F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : # pyre-ignore[16]
145+
].reshape(*batch_dim, 4)
119146

120147

121148
def _axis_angle_rotation(axis: str, angle):

tests/test_rotation_conversions.py

+49-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,9 @@
44
import itertools
55
import math
66
import unittest
7+
from typing import Optional, Union
78

9+
import numpy as np
810
import torch
911
from common_testing import TestCaseMixin
1012
from pytorch3d.transforms.rotation_conversions import (
@@ -64,7 +66,7 @@ def test_from_quat(self):
6466
"""quat -> mtx -> quat"""
6567
data = random_quaternions(13, dtype=torch.float64)
6668
mdata = matrix_to_quaternion(quaternion_to_matrix(data))
67-
self.assertClose(data, mdata)
69+
self._assert_quaternions_close(data, mdata)
6870

6971
def test_to_quat(self):
7072
"""mtx -> quat -> mtx"""
@@ -146,8 +148,7 @@ def test_quaternion_multiplication(self):
146148
b_matrix = quaternion_to_matrix(b)
147149
ab_matrix = torch.matmul(a_matrix, b_matrix)
148150
ab_from_matrix = matrix_to_quaternion(ab_matrix)
149-
self.assertEqual(ab.shape, ab_from_matrix.shape)
150-
self.assertClose(ab, ab_from_matrix)
151+
self._assert_quaternions_close(ab, ab_from_matrix)
151152

152153
def test_matrix_to_quaternion_corner_case(self):
153154
"""Check no bad gradients from sqrt(0)."""
@@ -161,7 +162,34 @@ def test_matrix_to_quaternion_corner_case(self):
161162
loss.backward()
162163
optimizer.step()
163164

164-
self.assertClose(matrix, 0.95 * torch.eye(3))
165+
self.assertClose(matrix, matrix, msg="Result has non-finite values")
166+
delta = 1e-2
167+
self.assertLess(
168+
matrix.trace(),
169+
3.0 - delta,
170+
msg="Identity initialisation unchanged by a gradient step",
171+
)
172+
173+
def test_matrix_to_quaternion_by_pi(self):
174+
# We check that rotations by pi around each of the 26
175+
# nonzero vectors containing nothing but 0, 1 and -1
176+
# are mapped to the right quaternions.
177+
# This is representative across the directions.
178+
options = [0.0, -1.0, 1.0]
179+
axes = [
180+
torch.tensor(vec)
181+
for vec in itertools.islice( # exclude [0, 0, 0]
182+
itertools.product(options, options, options), 1, None
183+
)
184+
]
185+
186+
axes = torch.nn.functional.normalize(torch.stack(axes), dim=-1)
187+
# Rotation by pi around unit vector x is given by
188+
# the matrix 2 x x^T - Id.
189+
R = 2 * torch.matmul(axes[..., None], axes[..., None, :]) - torch.eye(3)
190+
quats_hat = matrix_to_quaternion(R)
191+
R_hat = quaternion_to_matrix(quats_hat)
192+
self.assertClose(R, R_hat, atol=1e-3)
165193

166194
def test_from_axis_angle(self):
167195
"""axis_angle -> mtx -> axis_angle"""
@@ -228,3 +256,20 @@ def test_6d(self):
228256
self.assertClose(
229257
torch.matmul(r, r.permute(0, 2, 1)), torch.eye(3).expand_as(r), atol=1e-6
230258
)
259+
260+
def _assert_quaternions_close(
261+
self,
262+
input: Union[torch.Tensor, np.ndarray],
263+
other: Union[torch.Tensor, np.ndarray],
264+
*,
265+
rtol: float = 1e-05,
266+
atol: float = 1e-08,
267+
equal_nan: bool = False,
268+
msg: Optional[str] = None,
269+
):
270+
self.assertEqual(np.shape(input), np.shape(other))
271+
dot = (input * other).sum(-1)
272+
ones = torch.ones_like(dot)
273+
self.assertClose(
274+
dot.abs(), ones, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=msg
275+
)

0 commit comments

Comments
 (0)