Skip to content

Commit 1e4a2e8

Browse files
davnov134facebook-github-bot
authored andcommitted
__getitem__ for Transform3D
Summary: Implements the `__getitem__` method for `Transform3D` Reviewed By: nikhilaravi Differential Revision: D23813975 fbshipit-source-id: 5da752ed8ea029ad0af58bb7a7856f0995519b7a
1 parent ac3f8dc commit 1e4a2e8

File tree

2 files changed

+135
-2
lines changed

2 files changed

+135
-2
lines changed

pytorch3d/transforms/transform3d.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import math
44
import warnings
5-
from typing import Optional
5+
from typing import List, Optional, Union
66

77
import torch
88

@@ -172,6 +172,22 @@ def __init__(
172172
def __len__(self):
173173
return self.get_matrix().shape[0]
174174

175+
def __getitem__(
176+
self, index: Union[int, List[int], slice, torch.Tensor]
177+
) -> "Transform3d":
178+
"""
179+
Args:
180+
index: Specifying the index of the transform to retrieve.
181+
Can be an int, slice, list of ints, boolean, long tensor.
182+
Supports negative indices.
183+
184+
Returns:
185+
Transform3d object with selected transforms. The tensors are not cloned.
186+
"""
187+
if isinstance(index, int):
188+
index = [index]
189+
return self.__class__(matrix=self.get_matrix()[index])
190+
175191
def compose(self, *others):
176192
"""
177193
Return a new Transform3d with the tranforms to compose stored as
@@ -361,6 +377,9 @@ def translate(self, *args, **kwargs):
361377
def scale(self, *args, **kwargs):
362378
return self.compose(Scale(device=self.device, *args, **kwargs))
363379

380+
def rotate(self, *args, **kwargs):
381+
return self.compose(Rotate(device=self.device, *args, **kwargs))
382+
364383
def rotate_axis_angle(self, *args, **kwargs):
365384
return self.compose(RotateAxisAngle(device=self.device, *args, **kwargs))
366385

tests/test_transforms.py

+115-1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import unittest
66

77
import torch
8+
from common_testing import TestCaseMixin
89
from pytorch3d.transforms.so3 import so3_exponential_map
910
from pytorch3d.transforms.transform3d import (
1011
Rotate,
@@ -15,7 +16,7 @@
1516
)
1617

1718

18-
class TestTransform(unittest.TestCase):
19+
class TestTransform(TestCaseMixin, unittest.TestCase):
1920
def test_to(self):
2021
tr = Translate(torch.FloatTensor([[1.0, 2.0, 3.0]]))
2122
R = torch.FloatTensor([[0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])
@@ -89,6 +90,22 @@ def test_translate(self):
8990
self.assertTrue(torch.allclose(points_out, points_out_expected))
9091
self.assertTrue(torch.allclose(normals_out, normals_out_expected))
9192

93+
def test_rotate(self):
94+
R = so3_exponential_map(torch.randn((1, 3)))
95+
t = Transform3d().rotate(R)
96+
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
97+
1, 3, 3
98+
)
99+
normals = torch.tensor(
100+
[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
101+
).view(1, 3, 3)
102+
points_out = t.transform_points(points)
103+
normals_out = t.transform_normals(normals)
104+
points_out_expected = torch.bmm(points, R)
105+
normals_out_expected = torch.bmm(normals, R)
106+
self.assertTrue(torch.allclose(points_out, points_out_expected))
107+
self.assertTrue(torch.allclose(normals_out, normals_out_expected))
108+
92109
def test_scale(self):
93110
t = Transform3d().scale(2.0).scale(0.5, 0.25, 1.0)
94111
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(
@@ -237,6 +254,103 @@ def test_inverse(self, batch_size=5):
237254
for m in (m1, m2, m3, m4):
238255
self.assertTrue(torch.allclose(m, m5, atol=1e-3))
239256

257+
def _check_indexed_transforms(self, t3d, t3d_selected, indices):
258+
t3d_matrix = t3d.get_matrix()
259+
t3d_selected_matrix = t3d_selected.get_matrix()
260+
for order_index, selected_index in indices:
261+
self.assertClose(
262+
t3d_matrix[selected_index], t3d_selected_matrix[order_index]
263+
)
264+
265+
def test_get_item(self, batch_size=5):
266+
device = torch.device("cuda:0")
267+
268+
matrices = torch.randn(
269+
size=[batch_size, 4, 4], device=device, dtype=torch.float32
270+
)
271+
272+
# init the Transforms3D class
273+
t3d = Transform3d(matrix=matrices)
274+
275+
# int index
276+
index = 1
277+
t3d_selected = t3d[index]
278+
self.assertEqual(len(t3d_selected), 1)
279+
self._check_indexed_transforms(t3d, t3d_selected, [(0, 1)])
280+
281+
# negative int index
282+
index = -1
283+
t3d_selected = t3d[index]
284+
self.assertEqual(len(t3d_selected), 1)
285+
self._check_indexed_transforms(t3d, t3d_selected, [(0, -1)])
286+
287+
# list index
288+
index = [1, 2]
289+
t3d_selected = t3d[index]
290+
self.assertEqual(len(t3d_selected), len(index))
291+
self._check_indexed_transforms(t3d, t3d_selected, enumerate(index))
292+
293+
# empty list index
294+
index = []
295+
t3d_selected = t3d[index]
296+
self.assertEqual(len(t3d_selected), 0)
297+
self.assertEqual(t3d_selected.get_matrix().nelement(), 0)
298+
299+
# slice index
300+
index = slice(0, 2, 1)
301+
t3d_selected = t3d[index]
302+
self.assertEqual(len(t3d_selected), 2)
303+
self._check_indexed_transforms(t3d, t3d_selected, [(0, 0), (1, 1)])
304+
305+
# empty slice index
306+
index = slice(0, 0, 1)
307+
t3d_selected = t3d[index]
308+
self.assertEqual(len(t3d_selected), 0)
309+
self.assertEqual(t3d_selected.get_matrix().nelement(), 0)
310+
311+
# bool tensor
312+
index = (torch.rand(batch_size) > 0.5).to(device)
313+
index[:2] = True # make sure smth is selected
314+
t3d_selected = t3d[index]
315+
self.assertEqual(len(t3d_selected), index.sum())
316+
self._check_indexed_transforms(
317+
t3d,
318+
t3d_selected,
319+
zip(
320+
torch.arange(index.sum()),
321+
torch.nonzero(index, as_tuple=False).squeeze(),
322+
),
323+
)
324+
325+
# all false bool tensor
326+
index = torch.zeros(batch_size).bool()
327+
t3d_selected = t3d[index]
328+
self.assertEqual(len(t3d_selected), 0)
329+
self.assertEqual(t3d_selected.get_matrix().nelement(), 0)
330+
331+
# int tensor
332+
index = torch.tensor([1, 2], dtype=torch.int64, device=device)
333+
t3d_selected = t3d[index]
334+
self.assertEqual(len(t3d_selected), index.numel())
335+
self._check_indexed_transforms(t3d, t3d_selected, enumerate(index.tolist()))
336+
337+
# negative int tensor
338+
index = -(torch.tensor([1, 2], dtype=torch.int64, device=device))
339+
t3d_selected = t3d[index]
340+
self.assertEqual(len(t3d_selected), index.numel())
341+
self._check_indexed_transforms(t3d, t3d_selected, enumerate(index.tolist()))
342+
343+
# invalid index
344+
for invalid_index in (
345+
torch.tensor([1, 0, 1], dtype=torch.float32, device=device), # float tensor
346+
1.2, # float index
347+
torch.tensor(
348+
[[1, 0, 1], [1, 0, 1]], dtype=torch.int32, device=device
349+
), # multidimensional tensor
350+
):
351+
with self.assertRaises(IndexError):
352+
t3d_selected = t3d[invalid_index]
353+
240354

241355
class TestTranslate(unittest.TestCase):
242356
def test_python_scalar(self):

0 commit comments

Comments
 (0)