Skip to content

Commit 90dc7a0

Browse files
davnov134facebook-github-bot
authored andcommitted
Initialization of Transform3D with a custom matrix.
Summary: Allows to initialize a Transform3D object with a batch of user-defined transformation matrices: ``` t = Transform3D(matrix=torch.randn(2, 4, 4)) ``` Reviewed By: nikhilaravi Differential Revision: D20693475 fbshipit-source-id: dccc49b2ca4c19a034844c63463953ba8f52c1bc
1 parent e37085d commit 90dc7a0

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

pytorch3d/transforms/transform3d.py

+31-2
Original file line numberDiff line numberDiff line change
@@ -134,8 +134,37 @@ class Transform3d:
134134
135135
"""
136136

137-
def __init__(self, dtype=torch.float32, device="cpu"):
138-
self._matrix = torch.eye(4, dtype=dtype, device=device).view(1, 4, 4)
137+
def __init__(
138+
self,
139+
dtype: torch.dtype = torch.float32,
140+
device="cpu",
141+
matrix: Optional[torch.Tensor] = None,
142+
):
143+
"""
144+
Args:
145+
dtype: The data type of the transformation matrix.
146+
to be used if `matrix = None`.
147+
device: The device for storing the implemented transformation.
148+
If `matrix != None`, uses the device of input `matrix`.
149+
matrix: A tensor of shape (4, 4) or of shape (minibatch, 4, 4)
150+
representing the 4x4 3D transformation matrix.
151+
If `None`, initializes with identity using
152+
the specified `device` and `dtype`.
153+
"""
154+
155+
if matrix is None:
156+
self._matrix = torch.eye(4, dtype=dtype, device=device).view(1, 4, 4)
157+
else:
158+
if matrix.ndim not in (2, 3):
159+
raise ValueError('"matrix" has to be a 2- or a 3-dimensional tensor.')
160+
if matrix.shape[-2] != 4 or matrix.shape[-1] != 4:
161+
raise ValueError(
162+
'"matrix" has to be a tensor of shape (minibatch, 4, 4)'
163+
)
164+
# set the device from matrix
165+
device = matrix.device
166+
self._matrix = matrix.view(-1, 4, 4)
167+
139168
self._transforms = [] # store transforms to compose
140169
self._lu = None
141170
self.device = device

tests/test_transforms.py

+13
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,19 @@ def test_clone(self):
5757
matrix2 = t_pair[1].get_matrix()
5858
self.assertTrue(torch.allclose(matrix1, matrix2))
5959

60+
def test_init_with_custom_matrix(self):
61+
for matrix in (torch.randn(10, 4, 4), torch.randn(4, 4)):
62+
t = Transform3d(matrix=matrix)
63+
self.assertTrue(t.device == matrix.device)
64+
self.assertTrue(t._matrix.dtype == matrix.dtype)
65+
self.assertTrue(torch.allclose(t._matrix, matrix.view(t._matrix.shape)))
66+
67+
def test_init_with_custom_matrix_errors(self):
68+
bad_shapes = [[10, 5, 4], [3, 4], [10, 4, 4, 1], [10, 4, 4, 2], [4, 4, 4, 3]]
69+
for bad_shape in bad_shapes:
70+
matrix = torch.randn(*bad_shape).float()
71+
self.assertRaises(ValueError, Transform3d, matrix=matrix)
72+
6073
def test_translate(self):
6174
t = Transform3d().translate(1, 2, 3)
6275
points = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.5, 0.5, 0.0]]).view(

0 commit comments

Comments
 (0)