Skip to content

Commit ac26841

Browse files
committed
fix(//py): Fix trtorch.Device alternate constructor options
There were issues setting fields of trtorch.Device via kwargs; this patch should resolve those and verify that they work. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 91f21d8 commit ac26841

File tree

2 files changed

+60
-7
lines changed

2 files changed

+60
-7
lines changed

Diff for: py/trtorch/Device.py

+12-7
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import torch
22

33
from trtorch import _types
4-
import logging
4+
import trtorch.logging
55
import trtorch._C
66

77
import warnings
@@ -54,23 +54,27 @@ def __init__(self, *args, **kwargs):
5454
else:
5555
self.dla_core = id
5656
self.gpu_id = 0
57-
logging.log(logging.log.Level.Warning,
58-
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
57+
trtorch.logging.log(trtorch.logging.Level.Warning,
58+
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
5959

6060
elif len(args) == 0:
61-
if not "gpu_id" in kwargs or not "dla_core" in kwargs:
61+
if "gpu_id" in kwargs or "dla_core" in kwargs:
6262
if "dla_core" in kwargs:
6363
self.device_type = _types.DeviceType.DLA
6464
self.dla_core = kwargs["dla_core"]
6565
if "gpu_id" in kwargs:
6666
self.gpu_id = kwargs["gpu_id"]
6767
else:
6868
self.gpu_id = 0
69-
logging.log(logging.log.Level.Warning,
70-
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
69+
trtorch.logging.log(trtorch.logging.Level.Warning,
70+
"Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
7171
else:
7272
self.gpu_id = kwargs["gpu_id"]
73-
self.device_type == _types.DeviceType.GPU
73+
self.device_type = _types.DeviceType.GPU
74+
else:
75+
raise ValueError(
76+
"Either gpu_id or dla_core or both must be defined if no string with device specs is provided as an arg"
77+
)
7478

7579
else:
7680
raise ValueError(
@@ -80,6 +84,7 @@ def __init__(self, *args, **kwargs):
8084
if "allow_gpu_fallback" in kwargs:
8185
if not isinstance(kwargs["allow_gpu_fallback"], bool):
8286
raise TypeError("allow_gpu_fallback must be a bool")
87+
self.allow_gpu_fallback = kwargs["allow_gpu_fallback"]
8388

8489
def __str__(self) -> str:
8590
return "Device(type={}, gpu_id={}".format(self.device_type, self.gpu_id) \

Diff for: tests/py/test_api.py

+48
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,53 @@ def test_is_colored_output_on(self):
230230
self.assertTrue(color)
231231

232232

233+
class TestDevice(unittest.TestCase):
234+
235+
def test_from_string_constructor(self):
236+
device = trtorch.Device("cuda:0")
237+
self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
238+
self.assertEqual(device.gpu_id, 0)
239+
240+
device = trtorch.Device("gpu:1")
241+
self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
242+
self.assertEqual(device.gpu_id, 1)
243+
244+
def test_from_string_constructor_dla(self):
245+
device = trtorch.Device("dla:0")
246+
self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
247+
self.assertEqual(device.gpu_id, 0)
248+
self.assertEqual(device.dla_core, 0)
249+
250+
device = trtorch.Device("dla:1", allow_gpu_fallback=True)
251+
self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
252+
self.assertEqual(device.gpu_id, 0)
253+
self.assertEqual(device.dla_core, 1)
254+
self.assertEqual(device.allow_gpu_fallback, True)
255+
256+
def test_kwargs_gpu(self):
257+
device = trtorch.Device(gpu_id=0)
258+
self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
259+
self.assertEqual(device.gpu_id, 0)
260+
261+
def test_kwargs_dla_and_settings(self):
262+
device = trtorch.Device(dla_core=1, allow_gpu_fallback=False)
263+
self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
264+
self.assertEqual(device.gpu_id, 0)
265+
self.assertEqual(device.dla_core, 1)
266+
self.assertEqual(device.allow_gpu_fallback, False)
267+
268+
device = trtorch.Device(gpu_id=1, dla_core=0, allow_gpu_fallback=True)
269+
self.assertEqual(device.device_type, trtorch.DeviceType.DLA)
270+
self.assertEqual(device.gpu_id, 1)
271+
self.assertEqual(device.dla_core, 0)
272+
self.assertEqual(device.allow_gpu_fallback, True)
273+
274+
def test_from_torch(self):
275+
device = trtorch.Device._from_torch_device(torch.device("cuda:0"))
276+
self.assertEqual(device.device_type, trtorch.DeviceType.GPU)
277+
self.assertEqual(device.gpu_id, 0)
278+
279+
233280
def test_suite():
234281
suite = unittest.TestSuite()
235282
suite.addTest(unittest.makeSuite(TestLoggingAPIs))
@@ -242,6 +289,7 @@ def test_suite():
242289
suite.addTest(
243290
TestModuleFallbackToTorch.parametrize(TestModuleFallbackToTorch, model=models.resnet18(pretrained=True)))
244291
suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))
292+
suite.addTest(unittest.makeSuite(TestDevice))
245293

246294
return suite
247295

0 commit comments

Comments
 (0)