Skip to content

Commit 0c02ae9

Browse files
shapovalovfacebook-github-bot
authored andcommitted
Adding utility methods to TensorProperties
Summary: Context: in the code we are releasing with CO3D dataset, we use `cuda()` on TensorProperties like Pointclouds and Cameras where we recursively move batch to a GPU. It would be good to push it to a release so we don’t need to depend on the nightly build. Additionally, I aligned the logic of `.to("cuda")` without device index to the one of `torch.Tensor` where the current device is populated to index. It should not affect any actual use cases but some tests had to be changed. Reviewed By: bottler Differential Revision: D29659529 fbshipit-source-id: abe58aeaca14bacc68da3e6cf5ae07df3353e3ce
1 parent fa44a05 commit 0c02ae9

File tree

6 files changed

+31
-9
lines changed

6 files changed

+31
-9
lines changed

pytorch3d/common/types.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,21 @@
1515
def make_device(device: Device) -> torch.device:
1616
"""
1717
Makes an actual torch.device object from the device specified as
18-
either a string or torch.device object.
18+
either a string or torch.device object. If the device is `cuda` without
19+
a specific index, the index of the current device is assigned.
1920
2021
Args:
2122
device: Device (as str or torch.device)
2223
2324
Returns:
2425
A matching torch.device object
2526
"""
26-
return torch.device(device) if isinstance(device, str) else device
27+
device = torch.device(device) if isinstance(device, str) else device
28+
if device.type == "cuda" and device.index is None: # pyre-ignore[16]
29+
# If cuda but with no index, then the current cuda device is indicated.
30+
# In that case, we fix to that device
31+
device = torch.device(f"cuda:{torch.cuda.current_device()}")
32+
return device
2733

2834

2935
def get_device(x, device: Optional[Device] = None) -> torch.device:

pytorch3d/renderer/utils.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import copy
99
import inspect
1010
import warnings
11-
from typing import Any, Union
11+
from typing import Any, Optional, Union
1212

1313
import numpy as np
1414
import torch
@@ -174,6 +174,12 @@ def to(self, device: Device = "cpu") -> "TensorProperties":
174174
setattr(self, k, v.to(device_))
175175
return self
176176

177+
def cpu(self) -> "TensorProperties":
178+
return self.to("cpu")
179+
180+
def cuda(self, device: Optional[int] = None) -> "TensorProperties":
181+
return self.to(f"cuda:{device}" if device is not None else "cuda")
182+
177183
def clone(self, other) -> "TensorProperties":
178184
"""
179185
Update the tensor properties of other with the cloned properties of self.

tests/test_meshes.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -709,9 +709,9 @@ def test_to(self):
709709
self.assertEqual(cpu_device, mesh.device)
710710
self.assertIs(mesh, converted_mesh)
711711

712-
cuda_device = torch.device("cuda")
712+
cuda_device = torch.device("cuda:0")
713713

714-
converted_mesh = mesh.to("cuda")
714+
converted_mesh = mesh.to("cuda:0")
715715
self.assertEqual(cuda_device, converted_mesh.device)
716716
self.assertEqual(cpu_device, mesh.device)
717717
self.assertIsNot(mesh, converted_mesh)

tests/test_rendering_utils.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,17 @@ def test_to(self):
3939
example = TensorPropertiesTestClass(x=10.0, y=(100.0, 200.0))
4040
device = torch.device("cuda:0")
4141
new_example = example.to(device=device)
42-
self.assertTrue(new_example.device == device)
42+
self.assertEqual(new_example.device, device)
43+
44+
example_cpu = example.cpu()
45+
self.assertEqual(example_cpu.device, torch.device("cpu"))
46+
47+
example_gpu = example.cuda()
48+
self.assertEqual(example_gpu.device.type, "cuda")
49+
self.assertIsNotNone(example_gpu.device.index)
50+
51+
example_gpu1 = example.cuda(1)
52+
self.assertEqual(example_gpu1.device, torch.device("cuda:1"))
4353

4454
def test_clone(self):
4555
# Check clone method

tests/test_shader.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
class TestShader(TestCaseMixin, unittest.TestCase):
2323
def test_to(self):
2424
cpu_device = torch.device("cpu")
25-
cuda_device = torch.device("cuda")
25+
cuda_device = torch.device("cuda:0")
2626

2727
R, T = look_at_view_transform()
2828

tests/test_transforms.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -50,9 +50,9 @@ def test_to(self):
5050
self.assertEqual(torch.float32, t.dtype)
5151
self.assertIsNot(t, cpu_t)
5252

53-
cuda_device = torch.device("cuda")
53+
cuda_device = torch.device("cuda:0")
5454

55-
cuda_t = t.to("cuda")
55+
cuda_t = t.to("cuda:0")
5656
self.assertEqual(cuda_device, cuda_t.device)
5757
self.assertEqual(cpu_device, t.device)
5858
self.assertEqual(torch.float32, cuda_t.dtype)

0 commit comments

Comments
 (0)