Commit d99169f

feat(//py): add user level device class in py for embed engine
Signed-off-by: Naren Dasan <[email protected]>
1 parent df87de3 commit d99169f

8 files changed: +155 -30 lines changed

Diff for: py/trtorch/Device.py

+110
@@ -0,0 +1,110 @@
+import torch
+
+from trtorch import _types
+import logging
+import trtorch._C
+
+import warnings
+
+
+class Device(object):
+    """
+    Defines a device that can be used to specify target devices for engines
+
+    Attributes:
+        device_type (trtorch.DeviceType): Target device type (GPU or DLA). Set implicitly based on if dla_core is specified.
+        gpu_id (int): Device ID for target GPU
+        dla_core (int): Core ID for target DLA core
+        allow_gpu_fallback (bool): Whether falling back to GPU if DLA cannot support an op should be allowed
+    """
+
+    device_type = None
+    gpu_id = -1
+    dla_core = -1
+    allow_gpu_fallback = False
+
+    def __init__(self, *args, **kwargs):
+        """ __init__ Method for trtorch.Device
+
+        Device accepts one of a few construction patterns
+
+        Args:
+            spec (str): String with device spec e.g. "dla:0" for dla, core_id 0
+
+        Keyword Arguments:
+            gpu_id (int): ID of target GPU (will get overridden to the GPU managing DLA if dla_core is specified). If specified, no positional arguments should be provided
+            dla_core (int): ID of target DLA core. If specified, no positional arguments should be provided.
+            allow_gpu_fallback (bool): Allow TensorRT to schedule operations on GPU if they are not supported on DLA (ignored if device type is not DLA)
+
+        Examples:
+            - Device("gpu:1")
+            - Device("cuda:1")
+            - Device("dla:0", allow_gpu_fallback=True)
+            - Device(gpu_id=0, dla_core=0, allow_gpu_fallback=True)
+            - Device(dla_core=0, allow_gpu_fallback=True)
+            - Device(gpu_id=1)
+        """
+        if len(args) == 1:
+            if not isinstance(args[0], str):
+                raise TypeError("When specifying Device through positional argument, argument must be str")
+            else:
+                (self.device_type, id) = Device._parse_device_str(args[0])
+                if self.device_type == _types.DeviceType.GPU:
+                    self.gpu_id = id
+                else:
+                    self.dla_core = id
+                    self.gpu_id = 0
+                    logging.warning("Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
+
+        elif len(args) == 0:
+            if "gpu_id" in kwargs or "dla_core" in kwargs:
+                if "dla_core" in kwargs:
+                    self.device_type = _types.DeviceType.DLA
+                    self.dla_core = kwargs["dla_core"]
+                    if "gpu_id" in kwargs:
+                        self.gpu_id = kwargs["gpu_id"]
+                    else:
+                        self.gpu_id = 0
+                        logging.warning("Setting GPU id to 0 for device because device 0 manages DLA on Xavier")
+                else:
+                    self.gpu_id = kwargs["gpu_id"]
+                    self.device_type = _types.DeviceType.GPU
+            else:
+                raise ValueError("Either gpu_id or dla_core must be specified as a keyword argument")
+
+        else:
+            raise ValueError(
+                "Unexpected number of positional arguments for class Device \n Found {} arguments, expected either zero or a single positional argument"
+                .format(len(args)))
+
+        if "allow_gpu_fallback" in kwargs:
+            if not isinstance(kwargs["allow_gpu_fallback"], bool):
+                raise TypeError("allow_gpu_fallback must be a bool")
+            self.allow_gpu_fallback = kwargs["allow_gpu_fallback"]
+
+    def __str__(self) -> str:
+        if self.device_type == _types.DeviceType.GPU:
+            return "Device(type={}, gpu_id={})".format(self.device_type, self.gpu_id)
+        else:
+            return "Device(type={}, gpu_id={}, dla_core={}, allow_gpu_fallback={})".format(
+                self.device_type, self.gpu_id, self.dla_core, self.allow_gpu_fallback)
+
+    def _to_internal(self) -> trtorch._C.Device:
+        internal_dev = trtorch._C.Device()
+        internal_dev.device_type = self.device_type
+        internal_dev.gpu_id = self.gpu_id
+        internal_dev.dla_core = self.dla_core
+        internal_dev.allow_gpu_fallback = self.allow_gpu_fallback
+        return internal_dev
+
+    @classmethod
+    def _from_torch_device(cls, torch_dev: torch.device):
+        if torch_dev.type != 'cuda':
+            raise ValueError("Torch Device specs must have type \"cuda\"")
+        gpu_id = torch_dev.index
+        return cls(gpu_id=gpu_id)
+
+    @staticmethod
+    def _parse_device_str(s):
+        s = s.lower()
+        spec = s.split(':')
+        if spec[0] == "gpu" or spec[0] == "cuda":
+            return (_types.DeviceType.GPU, int(spec[1]))
+        elif spec[0] == "dla":
+            return (_types.DeviceType.DLA, int(spec[1]))
+        else:
+            raise ValueError("Unknown device type in device spec, expected \"gpu\", \"cuda\" or \"dla\"")
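For reference, a short sketch of how the new class can be constructed, following the docstring examples above (device ids are illustrative):

    import trtorch

    # String spec form: "<device type>:<id>"
    gpu_dev = trtorch.Device("cuda:1")   # equivalent to trtorch.Device("gpu:1")
    dla_dev = trtorch.Device("dla:0", allow_gpu_fallback=True)

    # Keyword form: passing dla_core implies DLA, and gpu_id defaults to 0
    # because device 0 manages DLA on Xavier
    dla_dev = trtorch.Device(dla_core=0, allow_gpu_fallback=True)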

Diff for: py/trtorch/__init__.py

+1
@@ -14,6 +14,7 @@
 from trtorch._types import *
 from trtorch import logging
 from trtorch.Input import Input
+from trtorch.Device import Device
 
 
 def _register_with_torch():

Diff for: py/trtorch/_compile_spec.py

+29-20
@@ -3,6 +3,7 @@
 import trtorch._C
 from trtorch import _types
 from trtorch.Input import Input
+from trtorch.Device import Device
 
 import warnings
 
@@ -101,27 +102,35 @@ def _parse_device_type(device: Any) -> _types.DeviceType:
                         str(type(device)))
 
 
-def _parse_device(device_info: Dict[str, Any]) -> trtorch._C.Device:
-    info = trtorch._C.Device()
-    if "device_type" not in device_info:
-        raise KeyError("Device type is required parameter")
+def _parse_device(device_info: Any) -> trtorch._C.Device:
+    if isinstance(device_info, dict):
+        info = trtorch._C.Device()
+        if "device_type" not in device_info:
+            raise KeyError("Device type is required parameter")
+        else:
+            assert isinstance(device_info["device_type"], _types.DeviceType)
+            info.device_type = _parse_device_type(device_info["device_type"])
+
+        if "gpu_id" in device_info:
+            assert isinstance(device_info["gpu_id"], int)
+            info.gpu_id = device_info["gpu_id"]
+
+        if "dla_core" in device_info:
+            assert isinstance(device_info["dla_core"], int)
+            info.dla_core = device_info["dla_core"]
+
+        if "allow_gpu_fallback" in device_info:
+            assert isinstance(device_info["allow_gpu_fallback"], bool)
+            info.allow_gpu_fallback = device_info["allow_gpu_fallback"]
+
+        return info
+    elif isinstance(device_info, Device):
+        return device_info._to_internal()
+    elif isinstance(device_info, torch.device):
+        return (Device._from_torch_device(device_info))._to_internal()
     else:
-        assert isinstance(device_info["device_type"], _types.DeviceType)
-        info.device_type = _parse_device_type(device_info["device_type"])
-
-    if "gpu_id" in device_info:
-        assert isinstance(device_info["gpu_id"], int)
-        info.gpu_id = device_info["gpu_id"]
-
-    if "dla_core" in device_info:
-        assert isinstance(device_info["dla_core"], int)
-        info.dla_core = device_info["dla_core"]
-
-    if "allow_gpu_fallback" in device_info:
-        assert isinstance(device_info["allow_gpu_fallback"], bool)
-        info.allow_gpu_fallback = device_info["allow_gpu_fallback"]
-
-    return info
+        raise ValueError(
+            "Unsupported data for device specification. Expected either a dict, trtorch.Device or torch.Device")
 
 
 def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> trtorch._C.TorchFallback:
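With this change, a compile spec's "device" entry can be given in three forms. A sketch of each, assuming the enum and keys follow the dict path above (values are illustrative):

    import torch
    import trtorch

    # 1. The new user-level class
    spec = {"device": trtorch.Device("dla:0", allow_gpu_fallback=True)}

    # 2. A torch.device (its type must be "cuda")
    spec = {"device": torch.device("cuda:0")}

    # 3. The original dict form, still supported
    spec = {
        "device": {
            "device_type": trtorch.DeviceType.DLA,
            "gpu_id": 0,
            "dla_core": 0,
            "allow_gpu_fallback": True,
        }
    }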

Diff for: py/trtorch/_compiler.py

+7-7
@@ -5,6 +5,7 @@
 import trtorch._C
 from trtorch._compile_spec import _parse_compile_spec
 from trtorch._version import __version__
+from trtorch.Device import Device
 from types import FunctionType
 
 
@@ -42,8 +43,7 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.ScriptModule:
                 "dla_core": 0, # (DLA only) Target dla core id to run engine
                 "allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
             },
-            "op_precision": torch.half, # Operating precision set to FP16
-            "input_dtypes": [torch.float32] # List of datatypes that should be configured for each input. Supported options torch.{float|half|int8|int32|bool}.
+            "enabled_precisions": {torch.float, torch.half}, # Enabling FP16 kernels
             "refit": false, # enable refit
             "debug": false, # enable debuggable engine
             "strict_types": false, # kernels should strictly run in operating precision
@@ -61,7 +61,7 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.ScriptModule:
             }
         }
 
-        Input Sizes can be specified as torch sizes, tuples or lists. Op precisions can be specified using
+        Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
         torch datatypes or trtorch datatypes and you can use either torch devices or the trtorch device type enum
        to select device type.
 
@@ -110,7 +110,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: str,
         "dla_core": 0, # (DLA only) Target dla core id to run engine
         "allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
     },
-    "op_precision": torch.half, # Operating precision set to FP16
+    "enabled_precisions": {torch.float, torch.half}, # Enabling FP16 kernels
     # List of datatypes that should be configured for each input. Supported options torch.{float|half|int8|int32|bool}.
     "disable_tf32": False, # Force FP32 layers to use traditional FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulating the sum using 23-bit mantissas
     "refit": false, # enable refit
@@ -123,7 +123,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: str,
     "max_batch_size": 0, # Maximum batch size (must be >= 1 to be set, 0 means not set)
 }
 
-    Input Sizes can be specified as torch sizes, tuples or lists. Op precisions can be specified using
+    Input Sizes can be specified as torch sizes, tuples or lists. dtypes can be specified using
     torch datatypes or trtorch datatypes and you can use either torch devices or the trtorch device type enum
     to select device type.
 
@@ -137,7 +137,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: str,
     return trtorch._C.convert_graph_to_trt_engine(module._c, method_name, _parse_compile_spec(compile_spec))
 
 
-def embed_engine_in_new_module(serialized_engine: bytes) -> torch.jit.ScriptModule:
+def embed_engine_in_new_module(serialized_engine: bytes, device: Device) -> torch.jit.ScriptModule:
     """Takes a pre-built serialized TensorRT engine and embeds it within a TorchScript module
 
     Takes a pre-built serialized TensorRT engine (as bytes) and embeds it within a TorchScript module.
@@ -153,7 +153,7 @@ def embed_engine_in_new_module(serialized_engine: bytes) -> torch.jit.ScriptModule:
     Returns:
         torch.jit.ScriptModule: New TorchScript module with engine embedded
     """
-    cpp_mod = trtorch._C.embed_engine_in_new_module(serialized_engine)
+    cpp_mod = trtorch._C.embed_engine_in_new_module(serialized_engine, device._to_internal())
     return torch.jit._recursive.wrap_cpp_module(cpp_mod)
 
 
Diff for: py/trtorch/csrc/tensorrt_classes.cpp

+4
@@ -115,6 +115,10 @@ nvinfer1::DeviceType toTRTDeviceType(DeviceType value) {
   }
 }
 
+core::runtime::CudaDevice Device::toInternalRuntimeDevice() {
+  return core::runtime::CudaDevice(gpu_id, toTRTDeviceType(device_type));
+}
+
 std::string Device::to_str() {
   std::stringstream ss;
   std::string fallback = allow_gpu_fallback ? "True" : "False";

Diff for: py/trtorch/csrc/tensorrt_classes.h

+1
@@ -79,6 +79,7 @@ struct Device : torch::CustomClassHolder {
   ADD_FIELD_GET_SET(dla_core, int64_t);
   ADD_FIELD_GET_SET(allow_gpu_fallback, bool);
 
+  core::runtime::CudaDevice toInternalRuntimeDevice();
   std::string to_str();
 };
 

Diff for: py/trtorch/csrc/trtorch_py.cpp

+2-2
@@ -119,8 +119,8 @@ bool CheckMethodOperatorSupport(const torch::jit::Module& module, const std::string& method_name) {
   return core::CheckMethodOperatorSupport(module, method_name);
 }
 
-torch::jit::Module EmbedEngineInNewModule(const py::bytes& engine, core::runtime::CudaDevice& device) {
-  return core::EmbedEngineInNewModule(engine, device);
+torch::jit::Module EmbedEngineInNewModule(const py::bytes& engine, Device& device) {
+  return core::EmbedEngineInNewModule(engine, device.toInternalRuntimeDevice());
 }
 
 std::string get_build_info() {
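Seen from Python, this binding change means the user-level Device is converted twice on its way down; a sketch of that data flow using the private helpers added in this commit (internal APIs, shown only for illustration):

    import trtorch

    dev = trtorch.Device("gpu:0")
    internal = dev._to_internal()  # trtorch._C.Device, the pybind class above
    # trtorch._C.embed_engine_in_new_module(engine_bytes, internal) then calls
    # Device::toInternalRuntimeDevice() to build the core::runtime::CudaDevice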

Diff for: tests/py/test_api.py

+1-1
@@ -162,7 +162,7 @@ def test_pt_to_trt_to_pt(self):
         }
 
         trt_engine = trtorch.convert_method_to_trt_engine(self.ts_model, "forward", compile_spec)
-        trt_mod = trtorch.embed_engine_in_new_module(trt_engine)
+        trt_mod = trtorch.embed_engine_in_new_module(trt_engine, trtorch.Device("cuda:0"))
         same = (trt_mod(self.input) - self.ts_model(self.input)).abs().max()
         self.assertTrue(same < 2e-3)
 
