@@ -48,7 +48,11 @@ def compile(
48
48
save_timing_cache = False ,
49
49
cuda_graph_batch_size = - 1 ,
50
50
is_aten = False ,
51
+ explicit_batch_dimension = True ,
51
52
use_experimental_fx_rt = False ,
53
+ max_aux_streams = None ,
54
+ version_compatible = False ,
55
+ optimization_level = None ,
52
56
num_avg_timing_iters = 1 ,
53
57
torch_executed_ops = [],
54
58
torch_executed_modules = [],
@@ -67,7 +71,11 @@ def compile(
67
71
timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
68
72
save_timing_cache: Update timing cache with current timing cache data if set to True.
69
73
cuda_graph_batch_size: Cuda graph batch size, default to be -1.
74
+ explicit_batch_dimension: Whether to specify an explicit batch dimension to TRT
70
75
use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
76
+ max_aux_streams: max number of aux stream to use
77
+ version_compatible: enable version compatible feature
78
+ optimization_level: builder optimization level
71
79
Returns:
72
80
A torch.nn.Module lowered by TensorRT.
73
81
"""
@@ -122,7 +130,11 @@ def compile(
122
130
save_timing_cache = save_timing_cache ,
123
131
cuda_graph_batch_size = cuda_graph_batch_size ,
124
132
is_aten = is_aten ,
133
+ explicit_batch_dimension = explicit_batch_dimension ,
125
134
use_experimental_rt = use_experimental_fx_rt ,
135
+ max_aux_streams = max_aux_streams ,
136
+ version_compatible = version_compatible ,
137
+ optimization_level = optimization_level ,
126
138
)
127
139
lowerer = Lowerer .create (lower_setting = lower_setting )
128
140
return lowerer (module , inputs )
0 commit comments