Commit f2203fa

update typing

1 parent fdc7e3c

File tree (4 files changed, +24 −33 lines):

  py/torch_tensorrt/dynamo/_compiler.py
  py/torch_tensorrt/dynamo/_settings.py
  py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py
  py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py


py/torch_tensorrt/dynamo/_compiler.py

+12-12
@@ -96,8 +96,8 @@ def cross_compile_for_windows(
     strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
     immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
-    tiling_optimization_level: Optional[int] = _defaults.TILING_OPTIMIZATION_LEVEL,
-    l2_limit_for_tiling: Optional[int] = _defaults.L2_LIMIT_FOR_TILING,
+    tiling_optimization_level: int = _defaults.TILING_OPTIMIZATION_LEVEL,
+    l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module using TensorRT in Linux for Inference in Windows
@@ -171,8 +171,8 @@ def cross_compile_for_windows(
         strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
         immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
         enable_weight_streaming (bool): Enable weight streaming.
-        tiling_optimization_level (Optional[int]): The optimization level of tiling strategies. A Higher level allows TensorRT to spend more time searching for better optimization strategy. (We currently support [0, 1, 2, 3], default is 0)
-        l2_limit_for_tiling (Optional[int]): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        tiling_optimization_level (int): The optimization level of tiling strategies. A Higher level allows TensorRT to spend more time searching for better optimization strategy. (We currently support [0, 1, 2, 3], default is 0)
+        l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -419,8 +419,8 @@ def compile(
     strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
     immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
-    tiling_optimization_level: Optional[int] = _defaults.TILING_OPTIMIZATION_LEVEL,
-    l2_limit_for_tiling: Optional[int] = _defaults.L2_LIMIT_FOR_TILING,
+    tiling_optimization_level: int = _defaults.TILING_OPTIMIZATION_LEVEL,
+    l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     **kwargs: Any,
 ) -> torch.fx.GraphModule:
     """Compile an ExportedProgram module for NVIDIA GPUs using TensorRT
@@ -496,8 +496,8 @@ def compile(
         strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
         immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
         enable_weight_streaming (bool): Enable weight streaming.
-        tiling_optimization_level (Optional[int]): The optimization level of tiling strategies. A Higher level allows TensorRT to spend more time searching for better optimization strategy. (We currently support [0, 1, 2, 3], default is 0)
-        l2_limit_for_tiling (Optional[int]): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        tiling_optimization_level (int): The optimization level of tiling strategies. A Higher level allows TensorRT to spend more time searching for better optimization strategy. (We currently support [0, 1, 2, 3], default is 0)
+        l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
         **kwargs: Any,
     Returns:
         torch.fx.GraphModule: Compiled FX Module, when run it will execute via TensorRT
@@ -962,8 +962,8 @@ def convert_exported_program_to_serialized_trt_engine(
     strip_engine_weights: bool = _defaults.STRIP_ENGINE_WEIGHTS,
     immutable_weights: bool = _defaults.IMMUTABLE_WEIGHTS,
     enable_weight_streaming: bool = _defaults.ENABLE_WEIGHT_STREAMING,
-    tiling_optimization_level: Optional[int] = _defaults.TILING_OPTIMIZATION_LEVEL,
-    l2_limit_for_tiling: Optional[int] = _defaults.L2_LIMIT_FOR_TILING,
+    tiling_optimization_level: int = _defaults.TILING_OPTIMIZATION_LEVEL,
+    l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
     **kwargs: Any,
 ) -> bytes:
     """Convert an ExportedProgram to a serialized TensorRT engine
@@ -1027,8 +1027,8 @@ def convert_exported_program_to_serialized_trt_engine(
         strip_engine_weights (bool): Strip engine weights from the serialized engine. This is useful when the engine is to be deployed in an environment where the weights are not required.
         immutable_weights (bool): Build non-refittable engines. This is useful for some layers that are not refittable. If this argument is set to true, `strip_engine_weights` and `refit_identical_engine_weights` will be ignored.
         enable_weight_streaming (bool): Enable weight streaming.
-        tiling_optimization_level (Optional[int]): The optimization level of tiling strategies. A Higher level allows TensorRT to spend more time searching for better optimization strategy. (We currently support [0, 1, 2, 3], default is 0)
-        l2_limit_for_tiling (Optional[int]): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        tiling_optimization_level (int): The optimization level of tiling strategies. A Higher level allows TensorRT to spend more time searching for better optimization strategy. (We currently support [0, 1, 2, 3], default is 0)
+        l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
     Returns:
         bytes: Serialized TensorRT engine, can either be saved to a file or deserialized via TensorRT APIs
     """

py/torch_tensorrt/dynamo/_settings.py

+4-4
@@ -95,8 +95,8 @@ class CompilationSettings:
         enable_cross_compile_for_windows (bool): By default this is False means TensorRT engines can only be executed on the same platform where they were built.
             True will enable cross-platform compatibility which allows the engine to be built on Linux and run on Windows
         use_aot_joint_export (bool): Use aot_export_joint_simple, else wrap backend with AOT_autograd, required for distributed tensors
-        tiling_optimization_level (Optional[int]): The optimization level of tiling strategies. A Higher level allows TensorRT to spend more time searching for better optimization strategy. (We currently support [0, 1, 2, 3], default is 0)
-        l2_limit_for_tiling (Optional[int]): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+        tiling_optimization_level (int): The optimization level of tiling strategies. A Higher level allows TensorRT to spend more time searching for better optimization strategy. (We currently support [0, 1, 2, 3], default is 0)
+        l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
     """
 
     enabled_precisions: Set[dtype] = field(default_factory=lambda: ENABLED_PRECISIONS)
@@ -138,8 +138,8 @@ class CompilationSettings:
     enable_weight_streaming: bool = ENABLE_WEIGHT_STREAMING
     enable_cross_compile_for_windows: bool = ENABLE_CROSS_COMPILE_FOR_WINDOWS
     use_aot_joint_export: bool = USE_AOT_JOINT_EXPORT
-    tiling_optimization_level: Optional[int] = TILING_OPTIMIZATION_LEVEL
-    l2_limit_for_tiling: Optional[int] = L2_LIMIT_FOR_TILING
+    tiling_optimization_level: int = TILING_OPTIMIZATION_LEVEL
+    l2_limit_for_tiling: int = L2_LIMIT_FOR_TILING
 
 
 _SETTINGS_TO_BE_ENGINE_INVARIANT = (
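
Because the dataclass fields are now plain ints, CompilationSettings no longer carries None for these options. A small sketch constructing the settings directly; the 1 MiB limit is just an illustrative value:

from torch_tensorrt.dynamo._settings import CompilationSettings

# Both fields default to the int constants TILING_OPTIMIZATION_LEVEL and
# L2_LIMIT_FOR_TILING from _defaults; explicit values must also be ints.
settings = CompilationSettings(
    tiling_optimization_level=1,  # one of [0, 1, 2, 3]
    l2_limit_for_tiling=1 << 20,  # target L2 cache usage limit in bytes
)
print(settings.tiling_optimization_level, settings.l2_limit_for_tiling)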

py/torch_tensorrt/dynamo/conversion/_TRTInterpreter.py

+4-13
@@ -330,10 +330,7 @@ def _populate_trt_builder_config(
             builder_config.set_flag(trt.BuilderFlag.WEIGHT_STREAMING)
 
         if version.parse(trt.__version__) >= version.parse("10.8"):
-            if (
-                self.compilation_settings.tiling_optimization_level is None
-                or self.compilation_settings.tiling_optimization_level == 0
-            ):
+            if self.compilation_settings.tiling_optimization_level == 0:
                 builder_config.tiling_optimization_level = (
                     trt.TilingOptimizationLevel.NONE
                 )
@@ -354,15 +351,9 @@ def _populate_trt_builder_config(
                     f"Invalid tiling optimization level: {self.compilation_settings.tiling_optimization_level}. A valid value should be in [0, 1, 2, 3]."
                 )
 
-            if (
-                self.compilation_settings.l2_limit_for_tiling is None
-                or self.compilation_settings.l2_limit_for_tiling == -1
-            ):
-                builder_config.l2_limit_for_tiling = -1
-            else:
-                builder_config.l2_limit_for_tiling = (
-                    self.compilation_settings.l2_limit_for_tiling
-                )
+            builder_config.l2_limit_for_tiling = (
+                self.compilation_settings.l2_limit_for_tiling
+            )
 
         return builder_config
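
With None removed from the types, the "is None" branches above are dead code and the values can be forwarded to the builder config directly. A standalone sketch of the simplified flow; the helper name apply_tiling_settings is hypothetical, and the branches for levels 1-3 are elided here just as they are in the hunk:

import tensorrt as trt

def apply_tiling_settings(builder_config, settings):
    # Level 0 (the default) disables tiling optimization, matching the branch kept above.
    # Levels 1-3 are handled by branches outside this hunk; any other value triggers the
    # "Invalid tiling optimization level" error shown in the context line above.
    if settings.tiling_optimization_level == 0:
        builder_config.tiling_optimization_level = trt.TilingOptimizationLevel.NONE

    # A plain int with -1 meaning "no limit" can be assigned as-is, with no None special case.
    builder_config.l2_limit_for_tiling = settings.l2_limit_for_tiling
    return builder_config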

py/torch_tensorrt/dynamo/runtime/_MutableTorchTensorRTModule.py

+4-4
@@ -92,8 +92,8 @@ def __init__(
         dryrun: bool = _defaults.DRYRUN,
         hardware_compatible: bool = _defaults.HARDWARE_COMPATIBLE,
         timing_cache_path: str = _defaults.TIMING_CACHE_PATH,
-        tiling_optimization_level: Optional[int] = _defaults.TILING_OPTIMIZATION_LEVEL,
-        l2_limit_for_tiling: Optional[int] = _defaults.L2_LIMIT_FOR_TILING,
+        tiling_optimization_level: int = _defaults.TILING_OPTIMIZATION_LEVEL,
+        l2_limit_for_tiling: int = _defaults.L2_LIMIT_FOR_TILING,
         **kwargs: Any,
     ) -> None:
         """
@@ -135,8 +135,8 @@ def __init__(
             hardware_compatible (bool): Build the TensorRT engines compatible with GPU architectures other than that of the GPU on which the engine was built (currently works for NVIDIA Ampere and newer)
             timing_cache_path (str): Path to the timing cache if it exists (or) where it will be saved after compilation
             lazy_engine_init (bool): Defer setting up engines until the compilation of all engines is complete. Can allow larger models with multiple graph breaks to compile but can lead to oversubscription of GPU memory at runtime.
-            tiling_optimization_level (Optional[int]): The optimization level of tiling strategies. A Higher level allows TensorRT to spend more time searching for better optimization strategy. (We currently support [0, 1, 2, 3], default is 0)
-            l2_limit_for_tiling (Optional[int]): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
+            tiling_optimization_level (int): The optimization level of tiling strategies. A Higher level allows TensorRT to spend more time searching for better optimization strategy. (We currently support [0, 1, 2, 3], default is 0)
+            l2_limit_for_tiling (int): The target L2 cache usage limit (in bytes) for tiling optimization (default is -1 which means no limit).
             **kwargs: Any,
         Returns:
             MutableTorchTensorRTModule
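
The mutable module wrapper gets the same int typing. A hedged sketch, assuming torch_tensorrt.MutableTorchTensorRTModule accepts these compile options as keyword arguments in the way the __init__ signature above suggests; the toy model is hypothetical:

import torch
import torch_tensorrt

model = torch.nn.Linear(32, 32).cuda().eval()
mutable_module = torch_tensorrt.MutableTorchTensorRTModule(
    model,
    tiling_optimization_level=0,  # default: no tiling optimization
    l2_limit_for_tiling=-1,       # default: no L2 cache usage limit
)
out = mutable_module(torch.randn(4, 32).cuda())  # engine is (re)built as needed on call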
