Skip to content

Commit ce60d4b

Browse files
bottlerfacebook-github-bot
authored andcommitted
remove requires_grad from random rotations
Summary: Because rotations and (rotation) quaternions live on curved manifolds, it doesn't make sense to optimize them directly. Having a prominent option to require gradient on random ones may cause people to try, and isn't particularly useful. Reviewed By: theschnitz Differential Revision: D29160734 fbshipit-source-id: fc9e320672349fe334747c5b214655882a460a62
1 parent 31c448a commit ce60d4b

File tree

2 files changed

+12
-23
lines changed

2 files changed

+12
-23
lines changed

pytorch3d/transforms/rotation_conversions.py

+6-20
Original file line numberDiff line numberDiff line change
@@ -282,9 +282,7 @@ def matrix_to_euler_angles(matrix, convention: str):
282282
return torch.stack(o, -1)
283283

284284

285-
def random_quaternions(
286-
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
287-
):
285+
def random_quaternions(n: int, dtype: Optional[torch.dtype] = None, device=None):
288286
"""
289287
Generate random quaternions representing rotations,
290288
i.e. versors with nonnegative real part.
@@ -294,21 +292,17 @@ def random_quaternions(
294292
dtype: Type to return.
295293
device: Desired device of returned tensor. Default:
296294
uses the current device for the default tensor type.
297-
requires_grad: Whether the resulting tensor should have the gradient
298-
flag set.
299295
300296
Returns:
301297
Quaternions as tensor of shape (N, 4).
302298
"""
303-
o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
299+
o = torch.randn((n, 4), dtype=dtype, device=device)
304300
s = (o * o).sum(1)
305301
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
306302
return o
307303

308304

309-
def random_rotations(
310-
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
311-
):
305+
def random_rotations(n: int, dtype: Optional[torch.dtype] = None, device=None):
312306
"""
313307
Generate random rotations as 3x3 rotation matrices.
314308
@@ -317,35 +311,27 @@ def random_rotations(
317311
dtype: Type to return.
318312
device: Device of returned tensor. Default: if None,
319313
uses the current device for the default tensor type.
320-
requires_grad: Whether the resulting tensor should have the gradient
321-
flag set.
322314
323315
Returns:
324316
Rotation matrices as tensor of shape (n, 3, 3).
325317
"""
326-
quaternions = random_quaternions(
327-
n, dtype=dtype, device=device, requires_grad=requires_grad
328-
)
318+
quaternions = random_quaternions(n, dtype=dtype, device=device)
329319
return quaternion_to_matrix(quaternions)
330320

331321

332-
def random_rotation(
333-
dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
334-
):
322+
def random_rotation(dtype: Optional[torch.dtype] = None, device=None):
335323
"""
336324
Generate a single random 3x3 rotation matrix.
337325
338326
Args:
339327
dtype: Type to return
340328
device: Device of returned tensor. Default: if None,
341329
uses the current device for the default tensor type
342-
requires_grad: Whether the resulting tensor should have the gradient
343-
flag set
344330
345331
Returns:
346332
Rotation matrix as tensor of shape (3, 3).
347333
"""
348-
return random_rotations(1, dtype, device, requires_grad)[0]
334+
return random_rotations(1, dtype, device)[0]
349335

350336

351337
def standardize_quaternion(quaternions):

tests/test_rotation_conversions.py

+6-3
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def test_to_quat(self):
7676

7777
def test_quat_grad_exists(self):
7878
"""Quaternion calculations are differentiable."""
79-
rotation = random_rotation(requires_grad=True)
79+
rotation = random_rotation()
80+
rotation.requires_grad = True
8081
modified = quaternion_to_matrix(matrix_to_quaternion(rotation))
8182
[g] = torch.autograd.grad(modified.sum(), rotation)
8283
self.assertTrue(torch.isfinite(g).all())
@@ -131,7 +132,8 @@ def test_to_euler(self):
131132

132133
def test_euler_grad_exists(self):
133134
"""Euler angle calculations are differentiable."""
134-
rotation = random_rotation(dtype=torch.float64, requires_grad=True)
135+
rotation = random_rotation(dtype=torch.float64)
136+
rotation.requires_grad = True
135137
for convention in self._all_euler_angle_conventions():
136138
euler_angles = matrix_to_euler_angles(rotation, convention)
137139
mdata = euler_angles_to_matrix(euler_angles, convention)
@@ -218,7 +220,8 @@ def test_to_axis_angle(self):
218220

219221
def test_quaternion_application(self):
220222
"""Applying a quaternion is the same as applying the matrix."""
221-
quaternions = random_quaternions(3, torch.float64, requires_grad=True)
223+
quaternions = random_quaternions(3, torch.float64)
224+
quaternions.requires_grad = True
222225
matrices = quaternion_to_matrix(quaternions)
223226
points = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
224227
transform1 = quaternion_apply(quaternions, points)

0 commit comments

Comments
 (0)