-from typing import List, Dict, Any, Set, Union, Callable, TypeGuard
+from typing import List, Any, Set, Callable, TypeGuard, Optional

import torch_tensorrt.ts

import torch.fx
from enum import Enum

-import torch_tensorrt.fx
from torch_tensorrt.fx import InputTensorSpec
from torch_tensorrt.fx.utils import LowerPrecision

+from torch_tensorrt.dynamo.compile import compile as dynamo_compile
+from torch_tensorrt.fx.lower import compile as fx_compile
+from torch_tensorrt.ts._compiler import compile as torchscript_compile
+


def _non_fx_input_interface(
    inputs: List[Input | torch.Tensor | InputTensorSpec],
) -> TypeGuard[List[Input | torch.Tensor]]:
-    return all([isinstance(i, torch.Tensor | Input) for i in inputs])
+    return all(isinstance(i, torch.Tensor | Input) for i in inputs)


def _fx_input_interface(
    inputs: List[Input | torch.Tensor | InputTensorSpec],
) -> TypeGuard[List[InputTensorSpec | torch.Tensor]]:
-    return all([isinstance(i, torch.Tensor | InputTensorSpec) for i in inputs])
+    return all(isinstance(i, torch.Tensor | InputTensorSpec) for i in inputs)


class _IRType(Enum):
@@ -58,10 +61,10 @@ def _parse_module_type(module: Any) -> _ModuleType:


def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
-    module_is_tsable = any([module_type == t for t in [_ModuleType.nn, _ModuleType.ts]])
-    module_is_fxable = any([module_type == t for t in [_ModuleType.nn, _ModuleType.fx]])
+    module_is_tsable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.ts])
+    module_is_fxable = any(module_type == t for t in [_ModuleType.nn, _ModuleType.fx])

-    ir_targets_torchscript = any([ir == opt for opt in ["torchscript", "ts"]])
+    ir_targets_torchscript = any(ir == opt for opt in ["torchscript", "ts"])
    ir_targets_fx = ir == "fx"
    ir_targets_dynamo = ir == "dynamo"
    ir_targets_torch_compile = ir == "torch_compile"
@@ -97,8 +100,8 @@ def _get_target_ir(module_type: _ModuleType, ir: str) -> _IRType:
def compile(
    module: Any,
    ir: str = "default",
-    inputs: List[Input | torch.Tensor | InputTensorSpec] = [],
-    enabled_precisions: Set[torch.dtype | dtype] = set([torch.float]),
+    inputs: Optional[List[Input | torch.Tensor | InputTensorSpec]] = None,
+    enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
    **kwargs: Any,
) -> (
    torch.nn.Module | torch.jit.ScriptModule | torch.fx.GraphModule | Callable[..., Any]
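For context on the signature change above: a default like `inputs=[]` or `set([torch.float])` is evaluated once at function-definition time and shared by every call, so mutations leak between calls; the `None` sentinel plus in-body normalization avoids that. A minimal, hypothetical illustration (the helper names are made up, not part of this patch):

def collect_bad(item, bucket=[]):
    # The default list is created once and reused by every call that omits `bucket`.
    bucket.append(item)
    return bucket

def collect_good(item, bucket=None):
    # None sentinel: a fresh list is created per call, mirroring the pattern used in compile().
    bucket = bucket if bucket is not None else []
    bucket.append(item)
    return bucket

print(collect_bad(1), collect_bad(2))    # [1, 2] [1, 2] -- shared state across calls
print(collect_good(1), collect_good(2))  # [1] [2]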
@@ -138,6 +141,11 @@ def compile(
    Returns:
        torch.nn.Module: Compiled Module, when run it will execute via TensorRT
    """
+    input_list = inputs if inputs is not None else []
+    enabled_precisions_set = (
+        enabled_precisions if enabled_precisions is not None else {torch.float}
+    )
+
    module_type = _parse_module_type(module)
    target_ir = _get_target_ir(module_type, ir)
    if target_ir == _IRType.ts:
@@ -148,45 +156,50 @@ def compile(
                "Module was provided as a torch.nn.Module, trying to script the module with torch.jit.script. In the event of a failure please preconvert your module to TorchScript",
            )
            ts_mod = torch.jit.script(module)
-        assert _non_fx_input_interface(inputs)
-        compiled_ts_module: torch.jit.ScriptModule = torch_tensorrt.ts.compile(
-            ts_mod, inputs=inputs, enabled_precisions=enabled_precisions, **kwargs
+        assert _non_fx_input_interface(input_list)
+        compiled_ts_module: torch.jit.ScriptModule = torchscript_compile(
+            ts_mod,
+            inputs=input_list,
+            enabled_precisions=enabled_precisions_set,
+            **kwargs,
        )
        return compiled_ts_module
    elif target_ir == _IRType.fx:
        if (
-            torch.float16 in enabled_precisions
-            or torch_tensorrt.dtype.half in enabled_precisions
+            torch.float16 in enabled_precisions_set
+            or torch_tensorrt.dtype.half in enabled_precisions_set
        ):
            lower_precision = LowerPrecision.FP16
        elif (
-            torch.float32 in enabled_precisions
-            or torch_tensorrt.dtype.float in enabled_precisions
+            torch.float32 in enabled_precisions_set
+            or torch_tensorrt.dtype.float in enabled_precisions_set
        ):
            lower_precision = LowerPrecision.FP32
        else:
-            raise ValueError(f"Precision {enabled_precisions} not supported on FX")
+            raise ValueError(f"Precision {enabled_precisions_set} not supported on FX")

-        assert _fx_input_interface(inputs)
-        compiled_fx_module: torch.nn.Module = torch_tensorrt.fx.compile(
+        assert _fx_input_interface(input_list)
+        compiled_fx_module: torch.nn.Module = fx_compile(
            module,
-            inputs,
+            input_list,
            lower_precision=lower_precision,
            explicit_batch_dimension=True,
            dynamic_batch=False,
            **kwargs,
        )
        return compiled_fx_module
    elif target_ir == _IRType.dynamo:
-        compiled_aten_module: torch.fx.GraphModule = torch_tensorrt.dynamo.compile(
+        compiled_aten_module: torch.fx.GraphModule = dynamo_compile(
            module,
-            inputs=inputs,
-            enabled_precisions=enabled_precisions,
+            inputs=input_list,
+            enabled_precisions=enabled_precisions_set,
            **kwargs,
        )
        return compiled_aten_module
    elif target_ir == _IRType.torch_compile:
-        return torch_compile(module, enabled_precisions=enabled_precisions, **kwargs)
+        return torch_compile(
+            module, enabled_precisions=enabled_precisions_set, **kwargs
+        )
    else:
        raise RuntimeError("Module is an unknown format or the ir requested is unknown")
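As a usage sketch of the refactored entry point (assumes a CUDA device with TensorRT available; the toy model and shapes are illustrative), both keyword arguments may now be omitted and are normalized internally to `[]` and `{torch.float}`:

import torch
import torch_tensorrt

model = torch.nn.Sequential(torch.nn.Linear(8, 4), torch.nn.ReLU()).eval().cuda()

# ir="ts" dispatches to the TorchScript path above; enabled_precisions={torch.half}
# permits FP16 kernels in the generated engine.
trt_model = torch_tensorrt.compile(
    model,
    ir="ts",
    inputs=[torch_tensorrt.Input((1, 8))],
    enabled_precisions={torch.half},
)
print(trt_model(torch.randn(1, 8).cuda()))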
@@ -206,10 +219,10 @@ def torch_compile(module: torch.nn.Module, **kwargs: Any) -> Callable[..., Any]:

def convert_method_to_trt_engine(
    module: Any,
+    inputs: List[Input | torch.Tensor],
    method_name: str,
    ir: str = "default",
-    inputs: List[Input | torch.Tensor] = [],
-    enabled_precisions: Set[torch.dtype | dtype] = set([torch.float]),
+    enabled_precisions: Optional[Set[torch.dtype | dtype]] = None,
    **kwargs: Any,
) -> bytes:
    """Convert a TorchScript module method to a serialized TensorRT engine
@@ -242,6 +255,10 @@ def convert_method_to_trt_engine(
    Returns:
        bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
    """
+    enabled_precisions_set = (
+        enabled_precisions if enabled_precisions is not None else {torch.float}
+    )
+
    module_type = _parse_module_type(module)
    target_ir = _get_target_ir(module_type, ir)
    if target_ir == _IRType.ts:
@@ -254,9 +271,9 @@ def convert_method_to_trt_engine(
            ts_mod = torch.jit.script(module)
        return torch_tensorrt.ts.convert_method_to_trt_engine(
            ts_mod,
-            method_name,
            inputs=inputs,
-            enabled_precisions=enabled_precisions,
+            method_name=method_name,
+            enabled_precisions=enabled_precisions_set,
            **kwargs,
        )
    elif target_ir == _IRType.fx: