Skip to content

Commit ad966b7

Browse files
committed
feat: Support fallback options in trtorchc
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 75e86e8 commit ad966b7

File tree

2 files changed

+31
-1
lines changed

2 files changed

+31
-1
lines changed

Diff for: cpp/trtorchc/README.md

+5-1
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,17 @@ trtorchc [input_file_path] [output_file_path]
3737
--allow-gpu-fallback (Only used when targeting DLA
3838
(device-type)) Lets engine run layers on
3939
GPU if they are not supported on DLA
40+
--allow-torch-fallback Enable layers to run in torch
41+
if they are not supported in TensorRT
4042
-p[precision],
4143
--default-op-precision=[precision]
4244
Default operating precision for the
4345
engine (Int8 requires a
4446
calibration-cache argument) [ float |
4547
float32 | f32 | half | float16 | f16 |
4648
int8 | i8 ] (default: float)
49+
--forced-fallback-ops List of operators in the graph that
50+
should be forced to fall back to PyTorch for execution
4751
-d[type], --device-type=[type] The type of device the engine should be
4852
built for [ gpu | dla ] (default: gpu)
4953
--engine-capability=[capability] The type of device the engine should be
@@ -84,4 +88,4 @@ trtorchc [input_file_path] [output_file_path]
8488
e.g.
8589
```
8690
trtorchc tests/modules/ssd_traced.jit.pt ssd_trt.ts "[(1,3,300,300); (1,3,512,512); (1, 3, 1024, 1024)]" -p f16
87-
```
91+
```

Diff for: cpp/trtorchc/main.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,9 @@ int main(int argc, char** argv) {
163163
"(Only used when targeting DLA (device-type)) Lets engine run layers on GPU if they are not supported on DLA",
164164
{"allow-gpu-fallback"});
165165

166+
args::Flag allow_torch_fallback(
167+
parser, "allow-torch-fallback", "Enable layers to run in torch if they are not supported in TensorRT", {"allow-torch-fallback"});
168+
166169
args::Flag disable_tf32(
167170
parser, "disable-tf32", "Prevent Float32 layers from using the TF32 data format", {"disable-tf32"});
168171

@@ -191,6 +194,11 @@ int main(int argc, char** argv) {
191194
"file_path",
192195
"Path to calibration cache file to use for post training quantization",
193196
{"calibration-cache-file"});
197+
args::ValueFlag<std::string> forced_fallback_ops(
198+
parser,
199+
"forced_fallback_ops",
200+
"List of operators in the graph that should be forced to fall back to PyTorch for execution.",
201+
{"ffo", "forced-fallback-ops"});
194202
args::ValueFlag<int> num_min_timing_iters(
195203
parser, "num_iters", "Number of minimization timing iterations used to select kernels", {"num-min-timing-iter"});
196204
args::ValueFlag<int> num_avg_timing_iters(
@@ -266,6 +274,10 @@ int main(int argc, char** argv) {
266274
compile_settings.device.allow_gpu_fallback = true;
267275
}
268276

277+
if (allow_torch_fallback) {
278+
compile_settings.torch_fallback = trtorch::CompileSpec::TorchFallback(true);
279+
}
280+
269281
if (disable_tf32) {
270282
compile_settings.disable_tf32 = true;
271283
}
@@ -277,6 +289,20 @@ int main(int argc, char** argv) {
277289

278290
auto calibrator = trtorch::ptq::make_int8_cache_calibrator(calibration_cache_file_path);
279291

292+
if (forced_fallback_ops) {
293+
std::string fallback_ops = args::get(forced_fallback_ops);
294+
if (!allow_torch_fallback){
295+
trtorch::logging::log(
296+
trtorch::logging::Level::kERROR,
297+
"Forced fallback ops provided but allow_torch_fallback is False. Please use --allow-torch-fallback to enable automatic fallback of operators.");
298+
}
299+
std::string op;
300+
std::stringstream ss(fallback_ops);
301+
while (getline(ss, op, ',')) {
302+
compile_settings.torch_fallback.forced_fallback_ops.push_back(op);
303+
}
304+
}
305+
280306
if (op_precision) {
281307
auto precision = args::get(op_precision);
282308
std::transform(

0 commit comments

Comments
 (0)