Skip to content

Commit bcee361

Browse files
Alexey Sidnevfacebook-github-bot
Alexey Sidnev
authored andcommitted
Replace torch.det() with manual implementation for 3x3 matrix
Summary: # Background There is an unstable error during training (it can happen after several minutes or after several hours). The error is connected to `torch.det()` function in `_check_valid_rotation_matrix()`. if I remove the function `torch.det()` in `_check_valid_rotation_matrix()` or remove the whole functions `_check_valid_rotation_matrix()` the error is disappeared (D29555876). # Solution Replace `torch.det()` with manual implementation for 3x3 matrix. Reviewed By: patricklabatut Differential Revision: D29655924 fbshipit-source-id: 41bde1119274a705ab849751ece28873d2c45155
1 parent 2f668ec commit bcee361

File tree

3 files changed

+89
-1
lines changed

3 files changed

+89
-1
lines changed

pytorch3d/common/workaround.py

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
import torch
9+
10+
11+
def _safe_det_3x3(t: torch.Tensor):
12+
"""
13+
Fast determinant calculation for a batch of 3x3 matrices.
14+
15+
Note, result of this function might not be the same as `torch.det()`.
16+
The differences might be in the last significant digit.
17+
18+
Args:
19+
t: Tensor of shape (N, 3, 3).
20+
21+
Returns:
22+
Tensor of shape (N) with determinants.
23+
"""
24+
25+
det = (
26+
t[..., 0, 0] * (t[..., 1, 1] * t[..., 2, 2] - t[..., 1, 2] * t[..., 2, 1])
27+
- t[..., 0, 1] * (t[..., 1, 0] * t[..., 2, 2] - t[..., 2, 0] * t[..., 1, 2])
28+
+ t[..., 0, 2] * (t[..., 1, 0] * t[..., 2, 1] - t[..., 2, 0] * t[..., 1, 1])
29+
)
30+
31+
return det

pytorch3d/transforms/transform3d.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import torch
1212

1313
from ..common.types import Device, get_device, make_device
14+
from ..common.workaround import _safe_det_3x3
1415
from .rotation_conversions import _axis_angle_rotation
1516

1617

@@ -774,7 +775,7 @@ def _check_valid_rotation_matrix(R, tol: float = 1e-7):
774775
eye = torch.eye(3, dtype=R.dtype, device=R.device)
775776
eye = eye.view(1, 3, 3).expand(N, -1, -1)
776777
orthogonal = torch.allclose(R.bmm(R.transpose(1, 2)), eye, atol=tol)
777-
det_R = torch.det(R)
778+
det_R = _safe_det_3x3(R)
778779
no_distortion = torch.allclose(det_R, torch.ones_like(det_R))
779780
if not (orthogonal and no_distortion):
780781
msg = "R is not a valid rotation matrix"

tests/test_common_workaround.py

+56
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
import unittest
9+
10+
import numpy as np
11+
import torch
12+
from common_testing import TestCaseMixin
13+
from pytorch3d.common.workaround import _safe_det_3x3
14+
15+
16+
class TestSafeDet3x3(TestCaseMixin, unittest.TestCase):
17+
def setUp(self) -> None:
18+
super().setUp()
19+
torch.manual_seed(42)
20+
np.random.seed(42)
21+
22+
def _test_det_3x3(self, batch_size, device):
23+
t = torch.rand((batch_size, 3, 3), dtype=torch.float32, device=device)
24+
actual_det = _safe_det_3x3(t)
25+
expected_det = t.det()
26+
self.assertClose(actual_det, expected_det, atol=1e-7)
27+
28+
def test_empty_batch(self):
29+
self._test_det_3x3(0, torch.device("cpu"))
30+
self._test_det_3x3(0, torch.device("cuda:0"))
31+
32+
def test_manual(self):
33+
t = torch.Tensor(
34+
[
35+
[[1, 0, 0], [0, 1, 0], [0, 0, 1]],
36+
[[2, -5, 3], [0, 7, -2], [-1, 4, 1]],
37+
[[6, 1, 1], [4, -2, 5], [2, 8, 7]],
38+
]
39+
).to(dtype=torch.float32)
40+
expected_det = torch.Tensor([1, 41, -306]).to(dtype=torch.float32)
41+
self.assertClose(_safe_det_3x3(t), expected_det)
42+
43+
device_cuda = torch.device("cuda:0")
44+
self.assertClose(
45+
_safe_det_3x3(t.to(device=device_cuda)), expected_det.to(device=device_cuda)
46+
)
47+
48+
def test_regression(self):
49+
tries = 32
50+
device_cpu = torch.device("cpu")
51+
device_cuda = torch.device("cuda:0")
52+
batch_sizes = np.random.randint(low=1, high=128, size=tries)
53+
54+
for batch_size in batch_sizes:
55+
self._test_det_3x3(batch_size, device_cpu)
56+
self._test_det_3x3(batch_size, device_cuda)

0 commit comments

Comments
 (0)