@@ -163,6 +163,9 @@ int main(int argc, char** argv) {
163
163
" (Only used when targeting DLA (device-type)) Lets engine run layers on GPU if they are not supported on DLA" ,
164
164
{" allow-gpu-fallback" });
165
165
166
+ args::Flag allow_torch_fallback (
167
+ parser, " allow-torch-fallback" , " Enable layers to run in torch if they are not supported in TensorRT" , {" allow-torch-fallback" });
168
+
166
169
args::Flag disable_tf32 (
167
170
parser, " disable-tf32" , " Prevent Float32 layers from using the TF32 data format" , {" disable-tf32" });
168
171
@@ -191,6 +194,11 @@ int main(int argc, char** argv) {
191
194
" file_path" ,
192
195
" Path to calibration cache file to use for post training quantization" ,
193
196
{" calibration-cache-file" });
197
+ args::ValueFlag<std::string> forced_fallback_ops (
198
+ parser,
199
+ " forced_fallback_ops" ,
200
+ " List of operators in the graph that should be forced to fallback to Pytorch for execution." ,
201
+ {" ffo" , " forced-fallback-ops" });
194
202
args::ValueFlag<int > num_min_timing_iters (
195
203
parser, " num_iters" , " Number of minimization timing iterations used to select kernels" , {" num-min-timing-iter" });
196
204
args::ValueFlag<int > num_avg_timing_iters (
@@ -266,6 +274,10 @@ int main(int argc, char** argv) {
266
274
compile_settings.device .allow_gpu_fallback = true ;
267
275
}
268
276
277
+ if (allow_torch_fallback) {
278
+ compile_settings.torch_fallback = trtorch::CompileSpec::TorchFallback (true );
279
+ }
280
+
269
281
if (disable_tf32) {
270
282
compile_settings.disable_tf32 = true ;
271
283
}
@@ -277,6 +289,20 @@ int main(int argc, char** argv) {
277
289
278
290
auto calibrator = trtorch::ptq::make_int8_cache_calibrator (calibration_cache_file_path);
279
291
292
+ if (forced_fallback_ops) {
293
+ std::string fallback_ops = args::get (forced_fallback_ops);
294
+ if (!allow_torch_fallback){
295
+ trtorch::logging::log (
296
+ trtorch::logging::Level::kERROR ,
297
+ " Forced fallback ops provided but allow_torch_fallback is False. Please use --allow_torch_fallback to enable automatic fallback of operators." );
298
+ }
299
+ std::string op;
300
+ std::stringstream ss (fallback_ops);
301
+ while (getline (ss, op, ' ,' )) {
302
+ compile_settings.torch_fallback .forced_fallback_ops .push_back (op);
303
+ }
304
+ }
305
+
280
306
if (op_precision) {
281
307
auto precision = args::get (op_precision);
282
308
std::transform (
0 commit comments