@@ -48,7 +48,11 @@ def compile(
     save_timing_cache=False,
     cuda_graph_batch_size=-1,
     is_aten=False,
-    use_experimental_fx_rt=False,
+    explicit_batch_dimension=True,
+    use_experimental_rt=False,
+    max_aux_streams=None,
+    version_compatible=False,
+    optimization_level=None,
     num_avg_timing_iters=1,
     torch_executed_ops=[],
     torch_executed_modules=[],
@@ -67,13 +71,17 @@ def compile(
         timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
         save_timing_cache: Update timing cache with current timing cache data if set to True.
         cuda_graph_batch_size: Cuda graph batch size, default to be -1.
-        use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
+        explicit_batch_dimension: Whether to specify an explicit batch dimension to TRT
+        use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
+        max_aux_streams: max number of aux stream to use
+        version_compatible: enable version compatible feature
+        optimization_level: builder optimization level
     Returns:
         A torch.nn.Module lowered by TensorRT.
     """
-    if use_experimental_fx_rt and not explicit_batch_dimension:
+    if use_experimental_rt and not explicit_batch_dimension:
         raise ValueError(
-            "The experimental unifed runtime only supports explicit batch. Please make sure to set explicit_batch_dimension=True when use_experimental_fx_rt=True"
+            "The experimental unifed runtime only supports explicit batch. Please make sure to set explicit_batch_dimension=True when use_experimental_rt=True"
         )

     logger.warn(
@@ -122,7 +130,11 @@ def compile(
         save_timing_cache=save_timing_cache,
         cuda_graph_batch_size=cuda_graph_batch_size,
         is_aten=is_aten,
-        use_experimental_rt=use_experimental_fx_rt,
+        explicit_batch_dimension=explicit_batch_dimension,
+        use_experimental_rt=use_experimental_rt,
+        max_aux_streams=max_aux_streams,
+        version_compatible=version_compatible,
+        optimization_level=optimization_level,
     )
     lowerer = Lowerer.create(lower_setting=lower_setting)
     return lowerer(module, inputs)
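
After this change, the renamed and newly added keyword arguments are passed straight through to the lowering settings. The snippet below is a minimal usage sketch, not taken from the patch: it assumes the function above is importable as torch_tensorrt.fx.lower.compile, and SimpleModel is a toy model defined only for illustration; import paths and defaults may differ between releases.

import torch
import torch.nn as nn

from torch_tensorrt.fx.lower import compile as fx_compile  # assumed import path

# Toy model used only to demonstrate the call; replace with a real module.
class SimpleModel(nn.Module):
    def forward(self, x):
        return torch.relu(x)

model = SimpleModel().eval().cuda()
inputs = [torch.randn(1, 3, 224, 224, device="cuda")]

trt_module = fx_compile(
    model,
    inputs,
    explicit_batch_dimension=True,  # must stay True when use_experimental_rt=True
    use_experimental_rt=True,       # next-generation TRTModule runtime
    max_aux_streams=None,           # let TensorRT pick the number of auxiliary streams
    version_compatible=False,       # enable version-compatible engines if True
    optimization_level=None,        # builder optimization level (None = default)
)

print(trt_module(*inputs).shape)

Note that, per the guard added in the patch, passing use_experimental_rt=True together with explicit_batch_dimension=False raises a ValueError.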