Commit 6553035

refactor: Modify prepare_inputs, remove lower_precision
Signed-off-by: Dheeraj Peri <[email protected]>

chore: refactor

Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 6c3b3f7 commit 6553035

File tree: 7 files changed, +69 -130 lines

py/torch_tensorrt/dynamo/_defaults.py (+2, -3)

```diff
@@ -1,7 +1,6 @@
-from torch_tensorrt.fx.utils import LowerPrecision
+import torch
 
-
-PRECISION = LowerPrecision.FP32
+PRECISION = torch.float32
 DEBUG = False
 WORKSPACE_SIZE = 0
 MIN_BLOCK_SIZE = 5
```

py/torch_tensorrt/dynamo/_settings.py (+2, -3)

```diff
@@ -1,7 +1,6 @@
 from dataclasses import dataclass, field
 from typing import Optional, Sequence
-
-from torch_tensorrt.fx.utils import LowerPrecision
+import torch
 from torch_tensorrt.dynamo._defaults import (
     PRECISION,
     DEBUG,
@@ -17,7 +16,7 @@
 
 @dataclass
 class CompilationSettings:
-    precision: LowerPrecision = PRECISION
+    precision: torch.dtype = PRECISION
     debug: bool = DEBUG
     workspace_size: int = WORKSPACE_SIZE
     min_block_size: int = MIN_BLOCK_SIZE
```
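
With this change the `precision` field is an ordinary `torch.dtype`, so building settings no longer pulls in the FX `LowerPrecision` enum. A minimal sketch of the new usage; the dataclass below is a trimmed stand-in mirroring only the fields visible in this diff, not the full class:

```python
from dataclasses import dataclass

import torch


@dataclass
class CompilationSettings:
    # Trimmed mirror of py/torch_tensorrt/dynamo/_settings.py; only the
    # fields shown in the diff above are reproduced here.
    precision: torch.dtype = torch.float32
    debug: bool = False
    workspace_size: int = 0
    min_block_size: int = 5


# Precision is now stored and compared as a plain torch dtype.
settings = CompilationSettings(precision=torch.float16)
assert settings.precision == torch.float16
```
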
(+0, -1)

```diff
@@ -1,2 +1 @@
 from .backends import torch_tensorrt_backend
-from .compile import compile
```

py/torch_tensorrt/dynamo/compile.py (+30, -87)

```diff
@@ -6,7 +6,6 @@
 
 from typing import Any, Optional, Sequence
 from torch_tensorrt import EngineCapability, Device
-from torch_tensorrt.fx.utils import LowerPrecision
 from torch.fx.passes.pass_manager import PassManager
 from torch.fx.passes.shape_prop import ShapeProp
 from torch_tensorrt.dynamo.aten_tracer import trace
@@ -78,119 +77,63 @@ def compile(
     if not isinstance(inputs, collections.abc.Sequence):
         inputs = [inputs]
 
-    inputs = prepare_inputs(inputs, prepare_device(device))
+    torchtrt_inputs, torch_inputs = prepare_inputs(inputs, prepare_device(device))
 
     if (
         torch.float16 in enabled_precisions
         or torch_tensorrt.dtype.half in enabled_precisions
     ):
-        lower_precision = LowerPrecision.FP16
+        precision = torch.float16
     elif (
         torch.float32 in enabled_precisions
         or torch_tensorrt.dtype.float in enabled_precisions
     ):
-        lower_precision = LowerPrecision.FP32
+        precision = torch.float32
    elif len(enabled_precisions) == 0:
         logger.info(f"No precision specified, defaulting to {PRECISION}")
-        lower_precision = PRECISION
+        precision = PRECISION
     else:
         raise ValueError(
             f"Precision {enabled_precisions} not supported in the Dynamo Path"
         )
 
+    compilation_options = {
+        "precision": precision,
+        "debug": debug,
+        "workspace_size": workspace_size,
+        "min_block_size": min_block_size,
+        "torch_executed_ops": torch_executed_ops,
+        "pass_through_build_failures": pass_through_build_failures,
+        "max_aux_streams": max_aux_streams,
+        "version_compatible": version_compatible,
+        "optimization_level": optimization_level,
+        "use_python_runtime": use_python_runtime,
+    }
+
     if kwargs.get("ir", "dynamo") == "torch_compile":
-        custom_backend = create_backend(
-            precision=lower_precision,
-            debug=debug,
-            workspace_size=workspace_size,
-            min_block_size=min_block_size,
-            torch_executed_ops=torch_executed_ops,
-            pass_through_build_failures=pass_through_build_failures,
-            max_aux_streams=max_aux_streams,
-            version_compatible=version_compatible,
-            optimization_level=optimization_level,
-            use_python_runtime=use_python_runtime,
-            **kwargs,
+        model = torch.compile(
+            gm,
+            backend=torch_tensorrt_backend,
+            options={**compilation_options, **kwargs},
         )
-        model = torch.compile(gm, backend=custom_backend)
         # Ensure compilation occurs by calling the function with provided inputs
-        model(*inputs)
+        model(*torch_inputs)
         return model
 
     else:
-        settings = CompilationSettings(
-            debug=debug,
-            precision=lower_precision,
-            workspace_size=workspace_size,
-            min_block_size=min_block_size,
-            torch_executed_ops=torch_executed_ops,
-            pass_through_build_failures=pass_through_build_failures,
-            max_aux_streams=max_aux_streams,
-            version_compatible=version_compatible,
-            optimization_level=optimization_level,
-            use_python_runtime=use_python_runtime,
-        )
-
-        model = trace(gm, inputs, **kwargs)
+        settings = CompilationSettings(**compilation_options)
+        model = trace(gm, torch_inputs, **kwargs)
 
         if kwargs.get("use_capability_partitioner", None):
-            model = lower_model(model, inputs)
-            return _compile_module(model, inputs, settings)
+            model = lower_model(model, torch_inputs)
+            return _compile_module(model, torch_inputs, settings)
         else:
-            split_result = lower_model_using_trt_splitter(model, inputs)
-            trt_module = _compile_graph(split_result, inputs, settings)
+            split_result = lower_model_using_trt_splitter(model, torch_inputs)
+            trt_module = _compile_graph(split_result, torch_inputs, settings)
 
             return trt_module
 
 
-def create_backend(
-    precision: LowerPrecision = PRECISION,
-    debug: bool = DEBUG,
-    workspace_size: int = WORKSPACE_SIZE,
-    min_block_size: int = MIN_BLOCK_SIZE,
-    torch_executed_ops: Sequence[str] = set(),
-    pass_through_build_failures: bool = PASS_THROUGH_BUILD_FAILURES,
-    max_aux_streams: Optional[int] = MAX_AUX_STREAMS,
-    version_compatible: bool = VERSION_COMPATIBLE,
-    optimization_level: Optional[int] = OPTIMIZATION_LEVEL,
-    use_python_runtime: Optional[bool] = USE_PYTHON_RUNTIME,
-    **kwargs,
-):
-    """Create torch.compile backend given specified arguments
-
-    Args:
-        precision: Model Layer precision
-        debug: Whether to print out verbose debugging information
-        workspace_size: Workspace TRT is allowed to use for the module (0 is default)
-        min_block_size: Minimum number of operators per TRT-Engine Block
-        torch_executed_ops: Sequence of operations to run in Torch, regardless of converter coverage
-        pass_through_build_failures: Whether to fail on TRT engine build errors (True) or not (False)
-        max_aux_streams: Maximum number of allowed auxiliary TRT streams for each engine
-        version_compatible: Provide version forward-compatibility for engine plan files
-        optimization_level: Builder optimization 0-5, higher levels imply longer build time,
-            searching for more optimization options. TRT defaults to 3
-        use_python_runtime: Whether to strictly use Python runtime or C++ runtime. To auto-select a runtime
-            based on C++ dependency presence (preferentially choosing C++ runtime if available), leave the
-            argument as None
-    Returns:
-        Backend for torch.compile
-    """
-    return partial(
-        torch_tensorrt_backend,
-        debug=debug,
-        precision=precision,
-        workspace_size=workspace_size,
-        min_block_size=min_block_size,
-        torch_executed_ops=torch_executed_ops,
-        pass_through_build_failures=pass_through_build_failures,
-        max_aux_streams=max_aux_streams,
-        version_compatible=version_compatible,
-        optimization_level=optimization_level,
-        use_python_runtime=use_python_runtime,
-        **kwargs,
-    )
-
-
 def _compile_graph(
     split_result: TRTSplitter,
     inputs: Any,
@@ -234,7 +177,7 @@ def lower_model(model: torch.nn.Module, inputs: Any, **kwargs):
         [fuse_permute_matmul, fuse_permute_linear]
     )
     lowered_model = graph_optimization_pm(model)
-    if isinstance(lowered_model, torch.fx.GraphModule):
-        ShapeProp(lowered_model).propagate(*inputs)
+    # if isinstance(lowered_model, torch.fx.GraphModule):
+    #     ShapeProp(lowered_model).propagate(*inputs)
 
     return lowered_model
```
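
The two duplicated keyword lists that previously fed `create_backend` and `CompilationSettings` are now gathered once into `compilation_options`: the torch_compile path merges it with user `kwargs` into `torch.compile(..., options=...)`, and the dynamo path splats it into `CompilationSettings`. A self-contained sketch of the merge semantics (illustrative keys and values only):

```python
import torch

# Defaults gathered once, as in the new compile() above (subset of keys).
compilation_options = {
    "precision": torch.float32,
    "debug": False,
    "workspace_size": 0,
    "min_block_size": 5,
}

# Caller-supplied kwargs, as forwarded from compile(**kwargs).
kwargs = {"debug": True}

# In {**compilation_options, **kwargs} the right-hand dict wins on key
# collisions, so user kwargs override the gathered defaults.
options = {**compilation_options, **kwargs}
assert options["debug"] is True
assert options["precision"] is torch.float32
```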

py/torch_tensorrt/dynamo/conversion/conversion.py (+1, -1)

```diff
@@ -41,7 +41,7 @@ def convert_module(
     )
     interpreter_result = interpreter.run(
         workspace_size=settings.workspace_size,
-        lower_precision=settings.precision,
+        precision=settings.precision,
         profiling_verbosity=(
             trt.ProfilingVerbosity.VERBOSE
             if settings.debug
```
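
Because `TRTInterpreter.run` now takes `precision=` (see the next file), this call site renames the keyword, and any external caller of `run` has to do the same. A runnable sketch of the renamed call shape, using a hypothetical stub in place of the real interpreter (which needs TensorRT to construct):

```python
import torch


class StubInterpreter:
    # Hypothetical stand-in for TRTInterpreter, showing only the keyword
    # rename: the parameter is `precision` (a torch.dtype), no longer
    # `lower_precision` (a LowerPrecision enum member).
    def run(self, workspace_size: int = 0, precision: torch.dtype = torch.float32):
        return f"build engine at {precision} with workspace_size={workspace_size}"


settings_precision = torch.float16  # e.g. CompilationSettings.precision

print(StubInterpreter().run(workspace_size=0, precision=settings_precision))
```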

py/torch_tensorrt/dynamo/conversion/trt_interpreter.py (+8, -17)

```diff
@@ -19,7 +19,6 @@
 from torch_tensorrt.fx.observer import Observer
 from torch_tensorrt.fx.utils import (
     get_dynamic_dims,
-    LowerPrecision,
     unified_dtype_converter,
     Frameworks,
 )
@@ -98,7 +97,7 @@ def validate_conversion(self):
     def run(
         self,
         workspace_size=0,
-        lower_precision=LowerPrecision.FP16,
+        precision=torch.float32,
         sparse_weights=False,
         disable_tf32=False,
         force_fp32_output=False,
@@ -115,7 +114,7 @@ def run(
         Build TensorRT engine with some configs.
         Args:
             workspace_size: Amount of memory used by TensorRT to store intermediate buffers within an operation.
-            lower_precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision).
+            precision: the precision model layers are running on (TensorRT will choose the best perforamnce precision).
             sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity
             force_fp32_output: force output to be fp32
             strict_type_constraints: Usually we should set it to False unless we want to control the precision of certain layer for numeric reasons.
@@ -131,22 +130,14 @@
         """
         TRT_INTERPRETER_CALL_PRE_OBSERVER.observe(self.module)
 
-        # For float outputs, we set their dtype to fp16 only if lower_precision == LowerPrecision.FP16 and
+        # For float outputs, we set their dtype to fp16 only if precision == torch.float16 and
         # force_fp32_output=False. Overriden by specifying output_dtypes
-        self.output_fp16 = (
-            not force_fp32_output and lower_precision == LowerPrecision.FP16
-        )
+        self.output_fp16 = not force_fp32_output and precision == torch.float16
 
-        if (
-            lower_precision == LowerPrecision.INT8
-            and not self.builder.platform_has_fast_int8
-        ):
+        if precision == torch.int8 and not self.builder.platform_has_fast_int8:
             raise RuntimeError("Current platform doesn't support fast native int8!")
 
-        if (
-            lower_precision == LowerPrecision.FP16
-            and not self.builder.platform_has_fast_fp16
-        ):
+        if precision == torch.float16 and not self.builder.platform_has_fast_fp16:
             warnings.warn("Current platform doesn't support fast native fp16!")
 
         self.input_specs_iter = 0
@@ -190,10 +181,10 @@ def run(
         _LOGGER.info(f"Using optimization level {optimization_level}")
         builder_config.builder_optimization_level = optimization_level
 
-        if lower_precision == LowerPrecision.FP16:
+        if precision == torch.float16:
             builder_config.set_flag(trt.BuilderFlag.FP16)
 
-        if lower_precision == LowerPrecision.INT8:
+        if precision == torch.int8:
             builder_config.set_flag(trt.BuilderFlag.INT8)
 
         if sparse_weights:
```
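
With `torch.dtype` values, the precision-to-builder-flag logic reduces to plain dtype comparisons. A self-contained sketch of that mapping; string names stand in for the real `trt.BuilderFlag` members so the sketch runs without a TensorRT install:

```python
import torch


def builder_flags_for(precision: torch.dtype) -> set:
    # Mirrors the dtype checks in TRTInterpreter.run above, with strings
    # standing in for trt.BuilderFlag.FP16 / trt.BuilderFlag.INT8.
    flags = set()
    if precision == torch.float16:
        flags.add("FP16")
    if precision == torch.int8:
        flags.add("INT8")
    # torch.float32 sets no flag: TensorRT builds FP32 engines by default.
    return flags


assert builder_flags_for(torch.float16) == {"FP16"}
assert builder_flags_for(torch.float32) == set()
```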

py/torch_tensorrt/dynamo/utils.py (+26, -18)

```diff
@@ -3,7 +3,7 @@
 from dataclasses import replace, fields
 from torch_tensorrt.dynamo import CompilationSettings
 from typing import Any, Union, Sequence, Dict
-from torch_tensorrt import _Input, Device
+from torch_tensorrt import Input, Device
 from typing import Optional
 
 logger = logging.getLogger(__name__)
@@ -55,43 +55,51 @@ def cosine_similarity(gt_tensor, pred_tensor):
 
 
 def prepare_inputs(
-    inputs: Union[_Input.Input, torch.Tensor, Sequence, Dict],
+    inputs: Union[Input, torch.Tensor, Sequence, Dict],
     device: torch.device = torch.device("cuda"),
 ) -> Any:
-    if isinstance(inputs, _Input.Input):
+    if isinstance(inputs, Input):
         if isinstance(inputs.shape, dict):
-            return inputs.example_tensor(optimization_profile_field="opt_shape").to(
-                device
-            )
+            return inputs, inputs.example_tensor(
+                optimization_profile_field="opt_shape"
+            ).to(device)
         else:
-            return inputs.example_tensor().to(device)
+            return inputs, inputs.example_tensor().to(device)
 
     elif isinstance(inputs, torch.Tensor):
-        return inputs
+        return Input.from_tensor(inputs), inputs
 
     elif isinstance(inputs, list):
         prepared_input = list()
-
+        torchtrt_inputs = []
+        torch_inputs = []
         for input_obj in inputs:
-            prepared_input.append(prepare_inputs(input_obj))
+            torchtrt_input, torch_input = prepare_inputs(input_obj)
+            torchtrt_inputs.append(torchtrt_input)
+            torch_inputs.append(torch_input)
 
-        return prepared_input
+        return torchtrt_inputs, torch_inputs
 
     elif isinstance(inputs, tuple):
-        prepared_input = list()
-
+        torchtrt_inputs = []
+        torch_inputs = []
         for input_obj in inputs:
-            prepared_input.append(prepare_inputs(input_obj))
+            torchtrt_input, torch_input = prepare_inputs(input_obj)
+            torchtrt_inputs.append(torchtrt_input)
+            torch_inputs.append(torch_input)
 
-        return tuple(prepared_input)
+        return tuple(torchtrt_inputs), tuple(torch_inputs)
 
     elif isinstance(inputs, dict):
-        prepared_input = dict()
+        torchtrt_inputs = dict()
+        torch_inputs = dict()
 
         for key, input_obj in inputs.items():
-            prepared_input[key] = prepare_inputs(input_obj)
+            torchtrt_input, torch_input = prepare_inputs(input_obj)
+            torchtrt_inputs[key] = torchtrt_input
+            torch_inputs[key] = torch_input
 
-        return prepared_input
+        return torchtrt_inputs, torch_inputs
 
     else:
         raise ValueError(
```
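
`prepare_inputs` now returns a pair for every accepted structure: a torch_tensorrt `Input` spec (or a nested container of them) alongside the matching example `torch.Tensor`s, which is what lets `compile` keep `torchtrt_inputs` and `torch_inputs` side by side. A hedged usage sketch, assuming a working torch_tensorrt install and a CUDA device:

```python
import torch
from torch_tensorrt import Input
from torch_tensorrt.dynamo.utils import prepare_inputs

# Mixed input styles are accepted: a concrete tensor and an Input spec.
inputs = [torch.randn(1, 3, 224, 224), Input(shape=(1, 3, 224, 224))]

# One call now yields two parallel lists: Input specs and example tensors.
torchtrt_inputs, torch_inputs = prepare_inputs(inputs, torch.device("cuda"))

assert all(isinstance(spec, Input) for spec in torchtrt_inputs)
assert all(isinstance(t, torch.Tensor) for t in torch_inputs)
```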
