|
| 1 | +# Copyright (c) Facebook, Inc. and its affiliates. |
| 2 | +# All rights reserved. |
| 3 | +# |
| 4 | +# This source code is licensed under the BSD-style license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | + |
| 7 | + |
| 8 | +import unittest |
| 9 | + |
| 10 | +import numpy as np |
| 11 | +import torch |
| 12 | +from common_testing import TestCaseMixin |
| 13 | +from pytorch3d.common.workaround import _safe_det_3x3 |
| 14 | + |
| 15 | + |
| 16 | +class TestSafeDet3x3(TestCaseMixin, unittest.TestCase): |
| 17 | + def setUp(self) -> None: |
| 18 | + super().setUp() |
| 19 | + torch.manual_seed(42) |
| 20 | + np.random.seed(42) |
| 21 | + |
| 22 | + def _test_det_3x3(self, batch_size, device): |
| 23 | + t = torch.rand((batch_size, 3, 3), dtype=torch.float32, device=device) |
| 24 | + actual_det = _safe_det_3x3(t) |
| 25 | + expected_det = t.det() |
| 26 | + self.assertClose(actual_det, expected_det, atol=1e-7) |
| 27 | + |
| 28 | + def test_empty_batch(self): |
| 29 | + self._test_det_3x3(0, torch.device("cpu")) |
| 30 | + self._test_det_3x3(0, torch.device("cuda:0")) |
| 31 | + |
| 32 | + def test_manual(self): |
| 33 | + t = torch.Tensor( |
| 34 | + [ |
| 35 | + [[1, 0, 0], [0, 1, 0], [0, 0, 1]], |
| 36 | + [[2, -5, 3], [0, 7, -2], [-1, 4, 1]], |
| 37 | + [[6, 1, 1], [4, -2, 5], [2, 8, 7]], |
| 38 | + ] |
| 39 | + ).to(dtype=torch.float32) |
| 40 | + expected_det = torch.Tensor([1, 41, -306]).to(dtype=torch.float32) |
| 41 | + self.assertClose(_safe_det_3x3(t), expected_det) |
| 42 | + |
| 43 | + device_cuda = torch.device("cuda:0") |
| 44 | + self.assertClose( |
| 45 | + _safe_det_3x3(t.to(device=device_cuda)), expected_det.to(device=device_cuda) |
| 46 | + ) |
| 47 | + |
| 48 | + def test_regression(self): |
| 49 | + tries = 32 |
| 50 | + device_cpu = torch.device("cpu") |
| 51 | + device_cuda = torch.device("cuda:0") |
| 52 | + batch_sizes = np.random.randint(low=1, high=128, size=tries) |
| 53 | + |
| 54 | + for batch_size in batch_sizes: |
| 55 | + self._test_det_3x3(batch_size, device_cpu) |
| 56 | + self._test_det_3x3(batch_size, device_cuda) |
0 commit comments