@@ -230,6 +230,53 @@ def test_is_colored_output_on(self):
230
230
self .assertTrue (color )
231
231
232
232
233
+ class TestDevice (unittest .TestCase ):
234
+
235
+ def test_from_string_constructor (self ):
236
+ device = trtorch .Device ("cuda:0" )
237
+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
238
+ self .assertEqual (device .gpu_id , 0 )
239
+
240
+ device = trtorch .Device ("gpu:1" )
241
+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
242
+ self .assertEqual (device .gpu_id , 1 )
243
+
244
+ def test_from_string_constructor_dla (self ):
245
+ device = trtorch .Device ("dla:0" )
246
+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
247
+ self .assertEqual (device .gpu_id , 0 )
248
+ self .assertEqual (device .dla_core , 0 )
249
+
250
+ device = trtorch .Device ("dla:1" , allow_gpu_fallback = True )
251
+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
252
+ self .assertEqual (device .gpu_id , 0 )
253
+ self .assertEqual (device .dla_core , 1 )
254
+ self .assertEqual (device .allow_gpu_fallback , True )
255
+
256
+ def test_kwargs_gpu (self ):
257
+ device = trtorch .Device (gpu_id = 0 )
258
+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
259
+ self .assertEqual (device .gpu_id , 0 )
260
+
261
+ def test_kwargs_dla_and_settings (self ):
262
+ device = trtorch .Device (dla_core = 1 , allow_gpu_fallback = False )
263
+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
264
+ self .assertEqual (device .gpu_id , 0 )
265
+ self .assertEqual (device .dla_core , 1 )
266
+ self .assertEqual (device .allow_gpu_fallback , False )
267
+
268
+ device = trtorch .Device (gpu_id = 1 , dla_core = 0 , allow_gpu_fallback = True )
269
+ self .assertEqual (device .device_type , trtorch .DeviceType .DLA )
270
+ self .assertEqual (device .gpu_id , 1 )
271
+ self .assertEqual (device .dla_core , 0 )
272
+ self .assertEqual (device .allow_gpu_fallback , True )
273
+
274
+ def test_from_torch (self ):
275
+ device = trtorch .Device ._from_torch_device (torch .device ("cuda:0" ))
276
+ self .assertEqual (device .device_type , trtorch .DeviceType .GPU )
277
+ self .assertEqual (device .gpu_id , 0 )
278
+
279
+
233
280
def test_suite ():
234
281
suite = unittest .TestSuite ()
235
282
suite .addTest (unittest .makeSuite (TestLoggingAPIs ))
@@ -242,6 +289,7 @@ def test_suite():
242
289
suite .addTest (
243
290
TestModuleFallbackToTorch .parametrize (TestModuleFallbackToTorch , model = models .resnet18 (pretrained = True )))
244
291
suite .addTest (unittest .makeSuite (TestCheckMethodOpSupport ))
292
+ suite .addTest (unittest .makeSuite (TestDevice ))
245
293
246
294
return suite
247
295
0 commit comments