4
4
import torch_tensorrt
5
5
from functools import partial
6
6
7
- from typing import Any , Sequence
7
+ from typing import Any , Optional , Sequence
8
8
from torch_tensorrt import EngineCapability , Device
9
9
from torch_tensorrt .fx .utils import LowerPrecision
10
10
14
14
from torch_tensorrt .dynamo .backend ._defaults import (
15
15
PRECISION ,
16
16
DEBUG ,
17
- MAX_WORKSPACE_SIZE ,
17
+ WORKSPACE_SIZE ,
18
18
MIN_BLOCK_SIZE ,
19
19
PASS_THROUGH_BUILD_FAILURES ,
20
+ MAX_AUX_STREAMS ,
21
+ VERSION_COMPATIBLE ,
22
+ OPTIMIZATION_LEVEL ,
23
+ USE_EXPERIMENTAL_RT ,
20
24
)
21
25
22
26
@@ -35,7 +39,7 @@ def compile(
35
39
debug = DEBUG ,
36
40
capability = EngineCapability .default ,
37
41
num_avg_timing_iters = 1 ,
38
- workspace_size = MAX_WORKSPACE_SIZE ,
42
+ workspace_size = WORKSPACE_SIZE ,
39
43
dla_sram_size = 1048576 ,
40
44
dla_local_dram_size = 1073741824 ,
41
45
dla_global_dram_size = 536870912 ,
@@ -45,6 +49,10 @@ def compile(
45
49
min_block_size = MIN_BLOCK_SIZE ,
46
50
torch_executed_ops = [],
47
51
torch_executed_modules = [],
52
+ max_aux_streams = MAX_AUX_STREAMS ,
53
+ version_compatible = VERSION_COMPATIBLE ,
54
+ optimization_level = OPTIMIZATION_LEVEL ,
55
+ use_experimental_rt = USE_EXPERIMENTAL_RT ,
48
56
** kwargs ,
49
57
):
50
58
if debug :
@@ -86,6 +94,10 @@ def compile(
86
94
workspace_size = workspace_size ,
87
95
min_block_size = min_block_size ,
88
96
torch_executed_ops = torch_executed_ops ,
97
+ max_aux_streams = max_aux_streams ,
98
+ version_compatible = version_compatible ,
99
+ optimization_level = optimization_level ,
100
+ use_experimental_rt = use_experimental_rt ,
89
101
** kwargs ,
90
102
)
91
103
@@ -105,19 +117,30 @@ def compile(
105
117
def create_backend (
106
118
precision : LowerPrecision = PRECISION ,
107
119
debug : bool = DEBUG ,
108
- workspace_size : int = MAX_WORKSPACE_SIZE ,
120
+ workspace_size : int = WORKSPACE_SIZE ,
109
121
min_block_size : int = MIN_BLOCK_SIZE ,
110
122
torch_executed_ops : Sequence [str ] = set (),
111
123
pass_through_build_failures : bool = PASS_THROUGH_BUILD_FAILURES ,
124
+ max_aux_streams : Optional [int ] = MAX_AUX_STREAMS ,
125
+ version_compatible : bool = VERSION_COMPATIBLE ,
126
+ optimization_level : Optional [int ] = OPTIMIZATION_LEVEL ,
127
+ use_experimental_rt : bool = USE_EXPERIMENTAL_RT ,
112
128
** kwargs ,
113
129
):
114
130
"""Create torch.compile backend given specified arguments
115
131
116
132
Args:
117
133
precision:
118
134
debug: Whether to print out verbose debugging information
119
- workspace_size: Maximum workspace TRT is allowed to use for the module
120
- precision: Model Layer precision
135
+ workspace_size: Workspace TRT is allowed to use for the module (0 is default)
136
+ min_block_size: Minimum number of operators per TRT-Engine Block
137
+ torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
138
+ pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
139
+ max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
140
+ version_compatible: Provide version forward-compatibility for engine plan files
141
+ optimization_level: Builder optimization 0-5, higher levels imply longer build time,
142
+ searching for more optimization options. TRT defaults to 3
143
+ use_experimental_rt: Whether to use the new experimental TRTModuleNext for TRT engines
121
144
Returns:
122
145
Backend for torch.compile
123
146
"""
@@ -131,6 +154,10 @@ def create_backend(
131
154
min_block_size = min_block_size ,
132
155
torch_executed_ops = torch_executed_ops ,
133
156
pass_through_build_failures = pass_through_build_failures ,
157
+ max_aux_streams = max_aux_streams ,
158
+ version_compatible = version_compatible ,
159
+ optimization_level = optimization_level ,
160
+ use_experimental_rt = use_experimental_rt ,
134
161
)
135
162
136
163
return partial (
0 commit comments