|
5 | 5 | import unittest
|
6 | 6 |
|
7 | 7 | import torch
|
| 8 | +from common_testing import TestCaseMixin |
8 | 9 | from pytorch3d.transforms.so3 import so3_exponential_map
|
9 | 10 | from pytorch3d.transforms.transform3d import (
|
10 | 11 | Rotate,
|
|
15 | 16 | )
|
16 | 17 |
|
17 | 18 |
|
18 |
| -class TestTransform(unittest.TestCase): |
| 19 | +class TestTransform(TestCaseMixin, unittest.TestCase): |
19 | 20 | def test_to(self):
|
20 | 21 | tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
|
21 | 22 | R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
|
@@ -89,6 +90,22 @@ def test_translate(self):
|
89 | 90 | self.assertTrue(torch.allclose(points_out, points_out_expected))
|
90 | 91 | self.assertTrue(torch.allclose(normals_out, normals_out_expected))
|
91 | 92 |
|
| 93 | + def test_rotate(self): |
| 94 | + R = so3_exponential_map(torch.randn((1, 3))) |
| 95 | + t = Transform3d().rotate(R) |
| 96 | + points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view( |
| 97 | + 1, 3, 3 |
| 98 | + ) |
| 99 | + normals = torch.tensor( |
| 100 | + [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]] |
| 101 | + ).view(1, 3, 3) |
| 102 | + points_out = t.transform_points(points) |
| 103 | + normals_out = t.transform_normals(normals) |
| 104 | + points_out_expected = torch.bmm(points, R) |
| 105 | + normals_out_expected = torch.bmm(normals, R) |
| 106 | + self.assertTrue(torch.allclose(points_out, points_out_expected)) |
| 107 | + self.assertTrue(torch.allclose(normals_out, normals_out_expected)) |
| 108 | + |
92 | 109 | def test_scale(self):
|
93 | 110 | t = Transform3d().scale(2.0).scale(0.5, 0.25, 1.0)
|
94 | 111 | points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
|
@@ -237,6 +254,103 @@ def test_inverse(self, batch_size=5):
|
237 | 254 | for m in (m1, m2, m3, m4):
|
238 | 255 | self.assertTrue(torch.allclose(m, m5, atol=1e-3))
|
239 | 256 |
|
| 257 | + def _check_indexed_transforms(self, t3d, t3d_selected, indices): |
| 258 | + t3d_matrix = t3d.get_matrix() |
| 259 | + t3d_selected_matrix = t3d_selected.get_matrix() |
| 260 | + for order_index, selected_index in indices: |
| 261 | + self.assertClose( |
| 262 | + t3d_matrix[selected_index], t3d_selected_matrix[order_index] |
| 263 | + ) |
| 264 | + |
| 265 | + def test_get_item(self, batch_size=5): |
| 266 | + device = torch.device("cuda:0") |
| 267 | + |
| 268 | + matrices = torch.randn( |
| 269 | + size=[batch_size, 4, 4], device=device, dtype=torch.float32 |
| 270 | + ) |
| 271 | + |
| 272 | + # init the Transforms3D class |
| 273 | + t3d = Transform3d(matrix=matrices) |
| 274 | + |
| 275 | + # int index |
| 276 | + index = 1 |
| 277 | + t3d_selected = t3d[index] |
| 278 | + self.assertEqual(len(t3d_selected), 1) |
| 279 | + self._check_indexed_transforms(t3d, t3d_selected, [(0, 1)]) |
| 280 | + |
| 281 | + # negative int index |
| 282 | + index = -1 |
| 283 | + t3d_selected = t3d[index] |
| 284 | + self.assertEqual(len(t3d_selected), 1) |
| 285 | + self._check_indexed_transforms(t3d, t3d_selected, [(0, -1)]) |
| 286 | + |
| 287 | + # list index |
| 288 | + index = [1, 2] |
| 289 | + t3d_selected = t3d[index] |
| 290 | + self.assertEqual(len(t3d_selected), len(index)) |
| 291 | + self._check_indexed_transforms(t3d, t3d_selected, enumerate(index)) |
| 292 | + |
| 293 | + # empty list index |
| 294 | + index = [] |
| 295 | + t3d_selected = t3d[index] |
| 296 | + self.assertEqual(len(t3d_selected), 0) |
| 297 | + self.assertEqual(t3d_selected.get_matrix().nelement(), 0) |
| 298 | + |
| 299 | + # slice index |
| 300 | + index = slice(0, 2, 1) |
| 301 | + t3d_selected = t3d[index] |
| 302 | + self.assertEqual(len(t3d_selected), 2) |
| 303 | + self._check_indexed_transforms(t3d, t3d_selected, [(0, 0), (1, 1)]) |
| 304 | + |
| 305 | + # empty slice index |
| 306 | + index = slice(0, 0, 1) |
| 307 | + t3d_selected = t3d[index] |
| 308 | + self.assertEqual(len(t3d_selected), 0) |
| 309 | + self.assertEqual(t3d_selected.get_matrix().nelement(), 0) |
| 310 | + |
| 311 | + # bool tensor |
| 312 | + index = (torch.rand(batch_size) > 0.5).to(device) |
| 313 | + index[:2] = True # make sure smth is selected |
| 314 | + t3d_selected = t3d[index] |
| 315 | + self.assertEqual(len(t3d_selected), index.sum()) |
| 316 | + self._check_indexed_transforms( |
| 317 | + t3d, |
| 318 | + t3d_selected, |
| 319 | + zip( |
| 320 | + torch.arange(index.sum()), |
| 321 | + torch.nonzero(index, as_tuple=False).squeeze(), |
| 322 | + ), |
| 323 | + ) |
| 324 | + |
| 325 | + # all false bool tensor |
| 326 | + index = torch.zeros(batch_size).bool() |
| 327 | + t3d_selected = t3d[index] |
| 328 | + self.assertEqual(len(t3d_selected), 0) |
| 329 | + self.assertEqual(t3d_selected.get_matrix().nelement(), 0) |
| 330 | + |
| 331 | + # int tensor |
| 332 | + index = torch.tensor([1, 2], dtype=torch.int64, device=device) |
| 333 | + t3d_selected = t3d[index] |
| 334 | + self.assertEqual(len(t3d_selected), index.numel()) |
| 335 | + self._check_indexed_transforms(t3d, t3d_selected, enumerate(index.tolist())) |
| 336 | + |
| 337 | + # negative int tensor |
| 338 | + index = -(torch.tensor([1, 2], dtype=torch.int64, device=device)) |
| 339 | + t3d_selected = t3d[index] |
| 340 | + self.assertEqual(len(t3d_selected), index.numel()) |
| 341 | + self._check_indexed_transforms(t3d, t3d_selected, enumerate(index.tolist())) |
| 342 | + |
| 343 | + # invalid index |
| 344 | + for invalid_index in ( |
| 345 | + torch.tensor([1, 0, 1], dtype=torch.float32, device=device), # float tensor |
| 346 | + 1.2, # float index |
| 347 | + torch.tensor( |
| 348 | + [[1, 0, 1], [1, 0, 1]], dtype=torch.int32, device=device |
| 349 | + ), # multidimensional tensor |
| 350 | + ): |
| 351 | + with self.assertRaises(IndexError): |
| 352 | + t3d_selected = t3d[invalid_index] |
| 353 | + |
240 | 354 |
|
241 | 355 | class TestTranslate(unittest.TestCase):
|
242 | 356 | def test_python_scalar(self):
|
|
0 commit comments