Skip to content

Commit 8b6686a

Browse files
wu6u3twnarendasan
authored and committed
Upstream 3 features to fx_ts_compat: MS, VC, Optimization Level (#1935)
1 parent 91299a5 commit 8b6686a

File tree

4 files changed

+73
-42
lines changed

4 files changed

+73
-42
lines changed

py/torch_tensorrt/dynamo/fx_ts_compat/fx2trt.py

+16
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ def run(
163163
timing_cache=None,
164164
profiling_verbosity=None,
165165
tactic_sources=None,
166+
max_aux_streams=None,
167+
version_compatible=False,
168+
optimization_level=None,
166169
) -> TRTInterpreterResult:
167170
"""
168171
Build TensorRT engine with some configs.
@@ -227,6 +230,18 @@ def run(
227230
if profiling_verbosity
228231
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY
229232
)
233+
234+
if trt.__version__ >= "8.6":
235+
if max_aux_streams is not None:
236+
_LOGGER.info(f"Setting max aux streams to {max_aux_streams}")
237+
builder_config.max_aux_streams = max_aux_streams
238+
if version_compatible:
239+
_LOGGER.info(f"Using version compatible")
240+
builder_config.set_flag(trt.BuilderFlag.VERSION_COMPATIBLE)
241+
if optimization_level is not None:
242+
_LOGGER.info(f"Using optimization level {optimization_level}")
243+
builder_config.builder_optimization_level = optimization_level
244+
230245
if lower_precision == LowerPrecision.FP16:
231246
builder_config.set_flag(trt.BuilderFlag.FP16)
232247

@@ -264,6 +279,7 @@ def run(
264279
_LOGGER.info(
265280
f"Build TRT engine elapsed time: {datetime.now() - build_engine_start_time}"
266281
)
282+
_LOGGER.info(f"TRT Engine uses: {engine.device_memory_size} bytes of Memory")
267283

268284
return TRTInterpreterResult(
269285
engine, self._input_names, self._output_names, serialized_cache

py/torch_tensorrt/dynamo/fx_ts_compat/lower.py

+3
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,9 @@ def __call__(self, mod, input, split_name) -> TRTInterpreterResult:
181181
if self.lower_setting.verbose_profile
182182
else trt.ProfilingVerbosity.LAYER_NAMES_ONLY,
183183
tactic_sources=self.lower_setting.tactic_sources,
184+
max_aux_streams=self.lower_setting.max_aux_streams,
185+
version_compatible=self.lower_setting.version_compatible,
186+
optimization_level=self.lower_setting.optimization_level,
184187
)
185188

186189
# Update timing cache file if needed

py/torch_tensorrt/dynamo/fx_ts_compat/lower_setting.py

+6
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,9 @@ class LowerSetting(LowerSettingBasic):
7070
correctness_atol: absolute tolerance for correctness check
7171
correctness_rtol: relative tolerance for correctness check
7272
use_experimental_rt: Uses the next generation TRTModule which supports both Python and TorchScript based execution (including in C++).
73+
max_aux_streams: max number of aux stream to use
74+
version_compatible: enable version compatible feature
75+
optimization_level: builder optimization level
7376
"""
7477

7578
input_specs: List[InputTensorSpec] = dc.field(default_factory=list)
@@ -96,3 +99,6 @@ class LowerSetting(LowerSettingBasic):
9699
correctness_atol: float = 0.1
97100
correctness_rtol: float = 0.1
98101
use_experimental_rt: bool = False
102+
max_aux_streams: Optional[int] = None
103+
version_compatible: bool = False
104+
optimization_level: Optional[int] = None

py/torch_tensorrt/dynamo/fx_ts_compat/passes/pass_utils.py

+48-42
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,10 @@ def parent_pass(module: fx.GraphModule, input: Input) -> fx.GraphModule:
126126
# (TODO(shirongwu): Add exception notification for fblearner flow when available, notify oncall
127127
# on pass that failed accuracy check.
128128
def validate_inference(
129-
rtol=None, atol=None, device=torch.device(torch.cuda.current_device())
129+
rtol=None,
130+
atol=None,
131+
device=torch.device(torch.cuda.current_device()),
132+
suppress_accuracy_check_failure=True,
130133
):
131134
def _validate_inference(pass_: PassFunc) -> PassFunc:
132135
"""
@@ -141,48 +144,51 @@ def pass_with_validation(
141144
*args,
142145
**kwargs,
143146
) -> fx.GraphModule:
144-
input_tensors = extract_example_tensors_from_input(input, device)
145-
res0 = module(*input_tensors)
146-
processed_module = pass_(module, input, *args, **kwargs)
147-
res1 = processed_module(*input_tensors)
148-
tensor_res_0 = _collect_tensors(res0)
149-
tensor_res_1 = _collect_tensors(res1)
150-
relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE
151-
152-
for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)):
153-
kwargs2 = {"equal_nan": True}
154-
if rtol:
155-
kwargs2["rtol"] = rtol
156-
if atol:
157-
kwargs2["atol"] = atol
158-
kwargs2[
159-
"msg"
160-
] = (
161-
lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
162-
)
163-
# If tensors are on different devices, make sure to compare
164-
# their copies that are on the same device.
165-
if x.get_device() != y.get_device():
166-
x = x.cpu()
167-
y = y.cpu()
168-
try:
169-
torch.testing.assert_close(x, y, **kwargs2)
170-
except Exception as e:
171-
if relax_accuracy_check_failure:
172-
_LOGGER.error(f"{e}")
173-
kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER
174-
kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER
175-
new_atol = kwargs2["atol"]
176-
new_rtol = kwargs2["rtol"]
177-
_LOGGER.info(
178-
f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}"
179-
)
147+
if suppress_accuracy_check_failure:
148+
return pass_(module, input, *args, **kwargs)
149+
else:
150+
input_tensors = extract_example_tensors_from_input(input, device)
151+
res0 = module(*input_tensors)
152+
processed_module = pass_(module, input, *args, **kwargs)
153+
res1 = processed_module(*input_tensors)
154+
tensor_res_0 = _collect_tensors(res0)
155+
tensor_res_1 = _collect_tensors(res1)
156+
relax_accuracy_check_failure = RELAX_ACCURACY_FAILURE
157+
158+
for kk, (x, y) in enumerate(zip(tensor_res_0, tensor_res_1)):
159+
kwargs2 = {"equal_nan": True}
160+
if rtol:
161+
kwargs2["rtol"] = rtol
162+
if atol:
163+
kwargs2["atol"] = atol
164+
kwargs2[
165+
"msg"
166+
] = (
167+
lambda msg: f"Pass {pass_} failed correctness check due at output {kk}:\n{msg}"
168+
)
169+
# If tensors are on different devices, make sure to compare
170+
# their copies that are on the same device.
171+
if x.get_device() != y.get_device():
172+
x = x.cpu()
173+
y = y.cpu()
174+
try:
180175
torch.testing.assert_close(x, y, **kwargs2)
181-
return processed_module
182-
else:
183-
raise e
184-
185-
return processed_module
176+
except Exception as e:
177+
if relax_accuracy_check_failure:
178+
_LOGGER.error(f"{e}")
179+
kwargs2["rtol"] *= FINAL_CHECK_RTOL_MULTIPLIER
180+
kwargs2["atol"] *= FINAL_CHECK_ATOL_MULTIPLIER
181+
new_atol = kwargs2["atol"]
182+
new_rtol = kwargs2["rtol"]
183+
_LOGGER.info(
184+
f"Do a sanity check to see whether things are completely wrong with {new_atol=}, {new_rtol=}"
185+
)
186+
torch.testing.assert_close(x, y, **kwargs2)
187+
return processed_module
188+
else:
189+
raise e
190+
191+
return processed_module
186192

187193
return pass_with_validation
188194

0 commit comments

Comments (0)