Skip to content

Commit acbc3ff

Browse files
committed
fix: Repair argument passing in both Dynamo paths
- Pass-through new TRT args in export - Pass-through build failures arg in compile
1 parent 5b156dc commit acbc3ff

File tree

2 files changed

+19
-5
lines changed

2 files changed

+19
-5
lines changed

py/torch_tensorrt/dynamo/backend/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def compile(
4545
min_block_size=MIN_BLOCK_SIZE,
4646
torch_executed_ops=[],
4747
torch_executed_modules=[],
48+
pass_through_build_failures=PASS_THROUGH_BUILD_FAILURES,
4849
**kwargs,
4950
):
5051
if debug:
@@ -86,6 +87,7 @@ def compile(
8687
workspace_size=workspace_size,
8788
min_block_size=min_block_size,
8889
torch_executed_ops=torch_executed_ops,
90+
pass_through_build_failures=pass_through_build_failures,
8991
**kwargs,
9092
)
9193

py/torch_tensorrt/dynamo/fx_ts_compat/lower.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,11 @@ def compile(
4848
save_timing_cache=False,
4949
cuda_graph_batch_size=-1,
5050
is_aten=False,
51-
use_experimental_fx_rt=False,
51+
explicit_batch_dimension=True,
52+
use_experimental_rt=False,
53+
max_aux_streams=None,
54+
version_compatible=False,
55+
optimization_level=None,
5256
num_avg_timing_iters=1,
5357
torch_executed_ops=[],
5458
torch_executed_modules=[],
@@ -67,13 +71,17 @@ def compile(
6771
timing_cache_prefix: Timing cache file name for timing cache used by fx2trt.
6872
save_timing_cache: Update timing cache with current timing cache data if set to True.
6973
cuda_graph_batch_size: Cuda graph batch size, default to be -1.
70-
use_experimental_fx_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
74+
explicit_batch_dimension: Whether to specify an explicit batch dimension to TRT
75+
use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
76+
max_aux_streams: max number of aux streams to use
77+
version_compatible: enable version compatible feature
78+
optimization_level: builder optimization level
7179
Returns:
7280
A torch.nn.Module lowered by TensorRT.
7381
"""
74-
if use_experimental_fx_rt and not explicit_batch_dimension:
82+
if use_experimental_rt and not explicit_batch_dimension:
7583
raise ValueError(
76-
"The experimental unifed runtime only supports explicit batch. Please make sure to set explicit_batch_dimension=True when use_experimental_fx_rt=True"
84+
"The experimental unifed runtime only supports explicit batch. Please make sure to set explicit_batch_dimension=True when use_experimental_rt=True"
7785
)
7886

7987
logger.warn(
@@ -122,7 +130,11 @@ def compile(
122130
save_timing_cache=save_timing_cache,
123131
cuda_graph_batch_size=cuda_graph_batch_size,
124132
is_aten=is_aten,
125-
use_experimental_rt=use_experimental_fx_rt,
133+
explicit_batch_dimension=explicit_batch_dimension,
134+
use_experimental_rt=use_experimental_rt,
135+
max_aux_streams=max_aux_streams,
136+
version_compatible=version_compatible,
137+
optimization_level=optimization_level,
126138
)
127139
lowerer = Lowerer.create(lower_setting=lower_setting)
128140
return lowerer(module, inputs)

0 commit comments

Comments
 (0)