Skip to content

Commit 46f727c

Browse files
shchhan123456facebook-github-bot
authored andcommitted
make so3_log_map torch script compatible
Summary: * HAT_INV_SKEW_SYMMETRIC_TOL was a global variable and torch script gives an error when compiling that function. Move it to the function scope. * torch script gives error when compiling acos_linear_extrapolation because bound is a union of tuple and float. The tuple version is kept in this diff. Reviewed By: patricklabatut Differential Revision: D30614916 fbshipit-source-id: 34258d200dc6a09fbf8917cac84ba8a269c00aef
1 parent c3d7808 commit 46f727c

File tree

4 files changed

+22
-26
lines changed

4 files changed

+22
-26
lines changed

pytorch3d/transforms/math.py

+12-17
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,12 @@
99

1010
import torch
1111

12+
DEFAULT_ACOS_BOUND = 1.0 - 1e-4
13+
1214

1315
def acos_linear_extrapolation(
1416
x: torch.Tensor,
15-
bound: Union[float, Tuple[float, float]] = 1.0 - 1e-4,
17+
bounds: Tuple[float, float] = (-DEFAULT_ACOS_BOUND, DEFAULT_ACOS_BOUND),
1618
) -> torch.Tensor:
1719
"""
1820
Implements `arccos(x)` which is linearly extrapolated outside `x`'s original
@@ -21,23 +23,20 @@ def acos_linear_extrapolation(
2123
2224
More specifically:
2325
```
24-
if -bound <= x <= bound:
26+
bounds=(lower_bound, upper_bound)
27+
if lower_bound <= x <= upper_bound:
2528
acos_linear_extrapolation(x) = acos(x)
26-
elif x <= -bound: # 1st order Taylor approximation
27-
acos_linear_extrapolation(x) = acos(-bound) + dacos/dx(-bound) * (x - (-bound))
28-
else: # x >= bound
29-
acos_linear_extrapolation(x) = acos(bound) + dacos/dx(bound) * (x - bound)
29+
elif x <= lower_bound: # 1st order Taylor approximation
30+
acos_linear_extrapolation(x) = acos(lower_bound) + dacos/dx(lower_bound) * (x - lower_bound)
31+
else: # x >= upper_bound
32+
acos_linear_extrapolation(x) = acos(upper_bound) + dacos/dx(upper_bound) * (x - upper_bound)
3033
```
31-
Note that `bound` can be made more specific with setting
32-
`bound=[lower_bound, upper_bound]` as detailed below.
3334
3435
Args:
3536
x: Input `Tensor`.
36-
bound: A float constant or a float 2-tuple defining the region for the
37+
bounds: A float 2-tuple defining the region for the
3738
linear extrapolation of `acos`.
38-
If `bound` is a float scalar, linearly interpolates acos for
39-
`x <= -bound` or `bound <= x`.
40-
If `bound` is a 2-tuple, the first/second element of `bound`
39+
The first/second element of `bound`
4140
describes the lower/upper bound that defines the lower/upper
4241
extrapolation region, i.e. the region where
4342
`x <= bound[0]`/`bound[1] <= x`.
@@ -46,11 +45,7 @@ def acos_linear_extrapolation(
4645
acos_linear_extrapolation: `Tensor` containing the extrapolated `arccos(x)`.
4746
"""
4847

49-
if isinstance(bound, float):
50-
upper_bound = bound
51-
lower_bound = -bound
52-
else:
53-
lower_bound, upper_bound = bound
48+
lower_bound, upper_bound = bounds
5449

5550
if lower_bound > upper_bound:
5651
raise ValueError("lower bound has to be smaller or equal to upper bound.")

pytorch3d/transforms/so3.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@
1212
from ..transforms import acos_linear_extrapolation
1313

1414

15-
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
16-
17-
1815
def so3_relative_angle(
1916
R1: torch.Tensor,
2017
R2: torch.Tensor,
@@ -104,7 +101,8 @@ def so3_rotation_angle(
104101
return phi_cos
105102
else:
106103
if cos_bound > 0.0:
107-
return acos_linear_extrapolation(phi_cos, 1.0 - cos_bound)
104+
bound = 1.0 - cos_bound
105+
return acos_linear_extrapolation(phi_cos, (-bound, bound))
108106
else:
109107
return torch.acos(phi_cos)
110108

@@ -250,6 +248,8 @@ def hat_inv(h: torch.Tensor) -> torch.Tensor:
250248
raise ValueError("Input has to be a batch of 3x3 Tensors.")
251249

252250
ss_diff = torch.abs(h + h.permute(0, 2, 1)).max()
251+
252+
HAT_INV_SKEW_SYMMETRIC_TOL = 1e-5
253253
if float(ss_diff) > HAT_INV_SKEW_SYMMETRIC_TOL:
254254
raise ValueError("One of input matrices is not skew-symmetric.")
255255

tests/test_acos_linear_extrapolation.py

-5
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,6 @@ def _one_acos_test(self, x: torch.Tensor, lower_bound: float, upper_bound: float
101101
self._test_acos_outside_bounds(
102102
x[x_lower], y[x_lower], dacos_dx[x_lower], lower_bound
103103
)
104-
if abs(upper_bound + lower_bound) <= 1e-5: # lower_bound==-upper_bound
105-
# check that passing bounds=upper_bound gives the same
106-
# resut as bounds=[lower_bound, upper_bound]
107-
y_one_bound = acos_linear_extrapolation(x, upper_bound)
108-
self.assertClose(y_one_bound, y)
109104

110105
def test_acos(self, batch_size: int = 10000):
111106
"""

tests/test_so3.py

+6
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
import math
99
import unittest
10+
from distutils.version import LooseVersion
1011

1112
import numpy as np
1213
import torch
@@ -268,6 +269,11 @@ def test_so3_cos_bound(self, batch_size: int = 100):
268269
# all grad values have to be finite
269270
self.assertTrue(torch.isfinite(r.grad).all())
270271

272+
@unittest.skipIf(LooseVersion(torch.__version__) < "1.9", "recent torchscript only")
273+
def test_scriptable(self):
274+
torch.jit.script(so3_exp_map)
275+
torch.jit.script(so3_log_map)
276+
271277
@staticmethod
272278
def so3_expmap(batch_size: int = 10):
273279
log_rot = TestSO3.init_log_rot(batch_size=batch_size)

0 commit comments

Comments
 (0)