@@ -73,16 +73,21 @@ def _parse_op_precision(precision: Any) -> _types.dtype:
73
73
74
74
def _parse_device_type(device: Any) -> _types.DeviceType:
    """Normalize a user-supplied device specification to a ``_types.DeviceType``.

    Args:
        device: One of ``torch.device`` (only ``'cuda'`` devices are accepted,
            mapped to GPU), an ``_types.DeviceType`` (returned unchanged), or
            the strings ``"gpu"``/``"GPU"`` or ``"dla"``/``"DLA"``.

    Returns:
        The corresponding ``_types.DeviceType`` enum value.

    Raises:
        ValueError: If a ``torch.device`` or string names an unsupported device.
        TypeError: If ``device`` is not one of the accepted types.
    """
    if isinstance(device, torch.device):
        if device.type == 'cuda':
            return _types.DeviceType.gpu
        else:
            # BUGFIX: the exception was previously constructed but never
            # raised, so this branch silently returned None.
            raise ValueError("Got a device type other than GPU or DLA (type: " + str(device.type) + ")")
    elif isinstance(device, _types.DeviceType):
        return device
    elif isinstance(device, str):
        if device == "gpu" or device == "GPU":
            return _types.DeviceType.gpu
        elif device == "dla" or device == "DLA":
            return _types.DeviceType.dla
        else:
            # BUGFIX: same missing ``raise`` as above.
            raise ValueError("Got a device type other than GPU or DLA (type: " + str(device) + ")")
    else:
        raise TypeError("Device specification must be of type torch.device, string or trtorch.DeviceType, but got: " + str(type(device)))
86
91
87
92
def _parse_compile_spec (compile_spec : Dict [str , Any ]) -> trtorch ._C .CompileSpec :
88
93
info = trtorch ._C .CompileSpec ()
@@ -110,11 +115,11 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
110
115
assert isinstance (compile_spec ["allow_gpu_fallback" ], bool )
111
116
info .allow_gpu_fallback = compile_spec ["allow_gpu_fallback" ]
112
117
113
- if "device " in compile_spec :
114
- info .device = _parse_device_type (compile_spec ["device " ])
118
+ if "device_type " in compile_spec :
119
+ info .device = _parse_device_type (compile_spec ["device_type " ])
115
120
116
121
if "capability" in compile_spec :
117
- assert isinstance (compile_spec ["capability" ], type .EngineCapability )
122
+ assert isinstance (compile_spec ["capability" ], _types .EngineCapability )
118
123
info .capability = compile_spec ["capability" ]
119
124
120
125
if "num_min_timing_iters" in compile_spec :
@@ -133,4 +138,74 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
133
138
assert type (compile_spec ["max_batch_size" ]) is int
134
139
info .max_batch_size = compile_spec ["max_batch_size" ]
135
140
136
- return info
141
+ return info
142
+
143
def TensorRTCompileSpec(compile_spec: Dict[str, Any]):
    """
    Utility to create a formated spec dictionary for using the PyTorch TensorRT backend

    Args:
        compile_spec (dict): Compilation settings including operating precision, target device, etc.
            One key is required which is ``input_shapes``, describing the input sizes or ranges for inputs
            to the graph. All other keys are optional. Entries for each method to be compiled.

            .. code-block:: py

                CompileSpec = {
                    "forward" : trtorch.TensorRTCompileSpec({
                        "input_shapes": [
                            (1, 3, 224, 224), # Static input shape for input #1
                            {
                                "min": (1, 3, 224, 224),
                                "opt": (1, 3, 512, 512),
                                "max": (1, 3, 1024, 1024)
                            } # Dynamic input shape for input #2
                        ],
                        "op_precision": torch.half, # Operating precision set to FP16
                        "refit": false, # enable refit
                        "debug": false, # enable debuggable engine
                        "strict_types": false, # kernels should strictly run in operating precision
                        "allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
                        "device": torch.device("cuda"), # Type of device to run engine on (for DLA use trtorch.DeviceType.DLA)
                        "capability": trtorch.EngineCapability.DEFAULT, # Restrict kernel selection to safe gpu kernels or safe dla kernels
                        "num_min_timing_iters": 2, # Number of minimization timing iterations used to select kernels
                        "num_avg_timing_iters": 1, # Number of averaging timing iterations used to select kernels
                        "workspace_size": 0, # Maximum size of workspace given to TensorRT
                        "max_batch_size": 0, # Maximum batch size (must be >= 1 to be set, 0 means not set)
                    })
                }

        Input Sizes can be specified as torch sizes, tuples or lists. Op precisions can be specified using
        torch datatypes or trtorch datatypes and you can use either torch devices or the trtorch device type enum
        to select device type.

    Returns:
        torch.classes.tensorrt.CompileSpec: List of methods and formated spec objects to be provided to ``torch._C._jit_to_tensorrt``
    """

    # Validate and normalize the user dict into the internal C++ spec first.
    parsed_spec = _parse_compile_spec(compile_spec)

    # Mirror the parsed spec into the TorchScript custom-class form expected
    # by the TensorRT backend (torch.classes.tensorrt.*).
    backend_spec = torch.classes.tensorrt.CompileSpec()

    for i in parsed_spec.input_ranges:
        ir = torch.classes.tensorrt.InputRange()
        ir.set_min(i.min)
        ir.set_opt(i.opt)
        ir.set_max(i.max)
        backend_spec.append_input_range(ir)

    # Enum-like fields are passed as ints; the backend class stores raw values.
    backend_spec.set_op_precision(int(parsed_spec.op_precision))
    # BUGFIX: set_refit was previously called twice; one redundant call removed.
    backend_spec.set_refit(parsed_spec.refit)
    backend_spec.set_debug(parsed_spec.debug)
    backend_spec.set_strict_types(parsed_spec.strict_types)
    backend_spec.set_allow_gpu_fallback(parsed_spec.allow_gpu_fallback)
    backend_spec.set_device(int(parsed_spec.device))
    backend_spec.set_capability(int(parsed_spec.capability))
    backend_spec.set_num_min_timing_iters(parsed_spec.num_min_timing_iters)
    backend_spec.set_num_avg_timing_iters(parsed_spec.num_avg_timing_iters)
    backend_spec.set_workspace_size(parsed_spec.workspace_size)
    backend_spec.set_max_batch_size(parsed_spec.max_batch_size)

    return backend_spec
211
+
0 commit comments