@@ -6,7 +6,6 @@
 
 from typing import Any, Optional, Sequence
 from torch_tensorrt import EngineCapability, Device
-from torch_tensorrt.fx.utils import LowerPrecision
 from torch.fx.passes.pass_manager import PassManager
 from torch.fx.passes.shape_prop import ShapeProp
 from torch_tensorrt.dynamo.aten_tracer import trace
@@ -78,119 +77,63 @@ def compile(
     if not isinstance(inputs, collections.abc.Sequence):
         inputs = [inputs]
 
-    inputs = prepare_inputs(inputs, prepare_device(device))
+    torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))
 
     if (
         torch.float16 in enabled_precisions
         or torch_tensorrt.dtype.half in enabled_precisions
     ):
-        lower_precision = LowerPrecision.FP16
+        precision = torch.float16
     elif (
         torch.float32 in enabled_precisions
         or torch_tensorrt.dtype.float in enabled_precisions
     ):
-        lower_precision = LowerPrecision.FP32
+        precision = torch.float32
     elif len(enabled_precisions) == 0:
         logger.info(f"No precision specified, defaulting to {PRECISION}")
-        lower_precision = PRECISION
+        precision = PRECISION
     else:
         raise ValueError(
             f"Precision {enabled_precisions} not supported in the Dynamo Path"
         )
 
+    compilation_options = {
+        "precision": precision,
+        "debug": debug,
+        "workspace_size": workspace_size,
+        "min_block_size": min_block_size,
+        "torch_executed_ops": torch_executed_ops,
+        "pass_through_build_failures": pass_through_build_failures,
+        "max_aux_streams": max_aux_streams,
+        "version_compatible": version_compatible,
+        "optimization_level": optimization_level,
+        "use_python_runtime": use_python_runtime,
+    }
+
     if kwargs.get("ir", "dynamo") == "torch_compile":
-        custom_backend = create_backend(
-            precision=lower_precision,
-            debug=debug,
-            workspace_size=workspace_size,
-            min_block_size=min_block_size,
-            torch_executed_ops=torch_executed_ops,
-            pass_through_build_failures=pass_through_build_failures,
-            max_aux_streams=max_aux_streams,
-            version_compatible=version_compatible,
-            optimization_level=optimization_level,
-            use_python_runtime=use_python_runtime,
-            **kwargs,
+        model = torch.compile(
+            gm,
+            backend=torch_tensorrt_backend,
+            options={**compilation_options, **kwargs},
         )
-        model = torch.compile(gm, backend=custom_backend)
         # Ensure compilation occurs by calling the function with provided inputs
-        model(*inputs)
+        model(*torch_inputs)
         return model
 
     else:
-        settings = CompilationSettings(
-            debug=debug,
-            precision=lower_precision,
-            workspace_size=workspace_size,
-            min_block_size=min_block_size,
-            torch_executed_ops=torch_executed_ops,
-            pass_through_build_failures=pass_through_build_failures,
-            max_aux_streams=max_aux_streams,
-            version_compatible=version_compatible,
-            optimization_level=optimization_level,
-            use_python_runtime=use_python_runtime,
-        )
-
-        model = trace(gm, inputs, **kwargs)
+        settings = CompilationSettings(**compilation_options)
+        model = trace(gm, torch_inputs, **kwargs)
 
         if kwargs.get("use_capability_partitioner", None):
-            model = lower_model(model, inputs)
-            return _compile_module(model, inputs, settings)
+            model = lower_model(model, torch_inputs)
+            return _compile_module(model, torch_inputs, settings)
         else:
-            split_result = lower_model_using_trt_splitter(model, inputs)
-            trt_module = _compile_graph(split_result, inputs, settings)
+            split_result = lower_model_using_trt_splitter(model, torch_inputs)
+            trt_module = _compile_graph(split_result, torch_inputs, settings)
 
             return trt_module
 
 
-def create_backend(
-    precision: LowerPrecision = PRECISION,
-    debug: bool = DEBUG,
-    workspace_size: int = WORKSPACE_SIZE,
-    min_block_size: int = MIN_BLOCK_SIZE,
-    torch_executed_ops: Sequence[str] = set(),
-    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
-    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
-    version_compatible: bool = VERSION_COMPATIBLE,
-    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
-    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
-    **kwargs,
-):
-    """Create torch.compile backend given specified arguments
-
-    Args:
-        precision: Model Layer precision
-        debug: Whether to print out verbose debugging information
-        workspace_size: Workspace TRT is allowed to use for the module (0 is default)
-        min_block_size: Minimum number of operators per TRT-Engine Block
-        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
-        pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
-        max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
-        version_compatible: Provide version forward-compatibility for engine plan files
-        optimization_level: Builder optimization 0-5, higher levels imply longer build time,
-            searching for more optimization options. TRT defaults to 3
-        use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
-            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
-            argument as None
-    Returns:
-        Backend for torch.compile
-    """
-    return partial(
-        torch_tensorrt_backend,
-        debug=debug,
-        precision=precision,
-        workspace_size=workspace_size,
-        min_block_size=min_block_size,
-        torch_executed_ops=torch_executed_ops,
-        pass_through_build_failures=pass_through_build_failures,
-        max_aux_streams=max_aux_streams,
-        version_compatible=version_compatible,
-        optimization_level=optimization_level,
-        use_python_runtime=use_python_runtime,
-        **kwargs,
-    )
-
-
 def _compile_graph(
     split_result: TRTSplitter,
     inputs: Any,
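
Why create_backend and its functools.partial wrapper can go away: torch.compile forwards its options dict to a callable backend whose signature declares an options parameter, so per-argument binding is redundant. A toy sketch of that mechanism (toy_backend is hypothetical; behavior as observed in recent PyTorch 2.x releases):

import torch

def toy_backend(gm: torch.fx.GraphModule, example_inputs, options=None):
    # torch.compile hands the options dict to backends that declare
    # an `options` parameter in their signature.
    print("received options:", options)
    return gm.forward  # execute the captured graph eagerly in this sketch

fn = torch.compile(
    torch.nn.Linear(4, 4),
    backend=toy_backend,
    options={"precision": torch.float16, "debug": True},
)
fn(torch.randn(2, 4))  # first call triggers compilation and the print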
@@ -234,7 +177,7 @@ def lower_model(model: torch.nn.Module, inputs: Any, **kwargs):
         [fuse_permute_matmul, fuse_permute_linear]
     )
     lowered_model = graph_optimization_pm(model)
-    if isinstance(lowered_model, torch.fx.GraphModule):
-        ShapeProp(lowered_model).propagate(*inputs)
+    # if isinstance(lowered_model, torch.fx.GraphModule):
+    #     ShapeProp(lowered_model).propagate(*inputs)
 
     return lowered_model
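
End to end, the refactored path can be exercised roughly as follows; this is a sketch that assumes the torch_tensorrt.dynamo.compile entry point, a CUDA device, and a working TensorRT install:

import torch
import torch_tensorrt

model = torch.nn.Sequential(torch.nn.Linear(32, 64), torch.nn.ReLU()).eval().cuda()
inputs = [torch.randn(8, 32).cuda()]

# enabled_precisions takes plain torch dtypes (or torch_tensorrt.dtype);
# the compile path above maps them onto the "precision" option.
trt_model = torch_tensorrt.dynamo.compile(
    model,
    inputs=inputs,
    enabled_precisions={torch.float16},
)
print(trt_model(*inputs).shape)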