@@ -52,7 +52,7 @@ trtorch::CompileSpec::TensorFormat parseTensorFormat(std::string str) {
52
52
} else {
53
53
trtorch::logging::log (
54
54
trtorch::logging::Level::kERROR ,
55
- " Invalid tensor format, options are [ linear | nchw | chw | contiguous | nhwc | hwc | channels_last ]" );
55
+ " Invalid tensor format, options are [ linear | nchw | chw | contiguous | nhwc | hwc | channels_last ], found: " + str );
56
56
return trtorch::CompileSpec::TensorFormat::kUnknown ;
57
57
}
58
58
}
@@ -73,7 +73,7 @@ trtorch::CompileSpec::DataType parseDataType(std::string dtype_str) {
73
73
} else {
74
74
trtorch::logging::log (
75
75
trtorch::logging::Level::kERROR ,
76
- " Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 | int | int32 | i32 | bool | b]" );
76
+ " Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 | int | int32 | i32 | bool | b], found: " + dtype_str );
77
77
return trtorch::CompileSpec::DataType::kUnknown ;
78
78
}
79
79
}
@@ -221,8 +221,8 @@ int main(int argc, char** argv) {
221
221
" type" ,
222
222
" The type of device the engine should be built for [ gpu | dla ] (default: gpu)" ,
223
223
{' d' , " device-type" });
224
- args::ValueFlag<int > gpu_id (parser, " gpu_id" , " GPU id if running on multi-GPU platform (defaults to 0)" , {" gpu-id" });
225
- args::ValueFlag<int > dla_core (
224
+ args::ValueFlag<uint64_t > gpu_id (parser, " gpu_id" , " GPU id if running on multi-GPU platform (defaults to 0)" , {" gpu-id" });
225
+ args::ValueFlag<uint64_t > dla_core (
226
226
parser, " dla_core" , " DLACore id if running on available DLA (defaults to 0)" , {" dla-core" });
227
227
228
228
args::ValueFlag<std::string> engine_capability (
@@ -243,13 +243,13 @@ int main(int argc, char** argv) {
243
243
" Whether to treat input file as a serialized TensorRT engine and embed it into a TorchScript module (device spec must be provided)" ,
244
244
{" embed-engine" });
245
245
246
- args::ValueFlag<int > num_min_timing_iters (
246
+ args::ValueFlag<uint64_t > num_min_timing_iters (
247
247
parser, " num_iters" , " Number of minimization timing iterations used to select kernels" , {" num-min-timing-iter" });
248
- args::ValueFlag<int > num_avg_timing_iters (
248
+ args::ValueFlag<uint64_t > num_avg_timing_iters (
249
249
parser, " num_iters" , " Number of averaging timing iterations used to select kernels" , {" num-avg-timing-iters" });
250
- args::ValueFlag<int > workspace_size (
250
+ args::ValueFlag<uint64_t > workspace_size (
251
251
parser, " workspace_size" , " Maximum size of workspace given to TensorRT" , {" workspace-size" });
252
- args::ValueFlag<int > max_batch_size (
252
+ args::ValueFlag<uint64_t > max_batch_size (
253
253
parser, " max_batch_size" , " Maximum batch size (must be >= 1 to be set, 0 means not set)" , {" max-batch-size" });
254
254
args::ValueFlag<double > threshold (
255
255
parser,
@@ -276,8 +276,8 @@ int main(int argc, char** argv) {
276
276
std::cout << parser;
277
277
return 0 ;
278
278
} catch (args::ParseError e) {
279
- std::cerr << e.what () << std::endl ;
280
- std::cerr << parser;
279
+ trtorch::logging::log (trtorch::logging::Level:: kERROR , e.what ()) ;
280
+ std::cerr << std::endl << parser;
281
281
return 1 ;
282
282
}
283
283
@@ -309,13 +309,13 @@ int main(int argc, char** argv) {
309
309
auto parsed_dtype = parseDataType (dtype);
310
310
if (parsed_dtype == trtorch::CompileSpec::DataType::kUnknown ) {
311
311
trtorch::logging::log (trtorch::logging::Level::kERROR , " Invalid datatype for input specification " + spec);
312
- std::cerr << parser;
312
+ std::cerr << std::endl << parser;
313
313
exit (1 );
314
314
}
315
315
auto parsed_format = parseTensorFormat (format);
316
316
if (parsed_format == trtorch::CompileSpec::TensorFormat::kUnknown ) {
317
317
trtorch::logging::log (trtorch::logging::Level::kERROR , " Invalid format for input specification " + spec);
318
- std::cerr << parser;
318
+ std::cerr << std::endl << parser;
319
319
exit (1 );
320
320
}
321
321
if (shapes.rfind (" (" , 0 ) == 0 ) {
@@ -326,7 +326,7 @@ int main(int argc, char** argv) {
326
326
trtorch::CompileSpec::Input (dyn_shapes[0 ], dyn_shapes[1 ], dyn_shapes[2 ], parsed_dtype, parsed_format));
327
327
} else {
328
328
trtorch::logging::log (trtorch::logging::Level::kERROR , spec_err_str);
329
- std::cerr << parser;
329
+ std::cerr << std::endl << parser;
330
330
exit (1 );
331
331
}
332
332
// THERE IS NO SPEC FOR FORMAT
@@ -337,7 +337,7 @@ int main(int argc, char** argv) {
337
337
auto parsed_dtype = parseDataType (dtype);
338
338
if (parsed_dtype == trtorch::CompileSpec::DataType::kUnknown ) {
339
339
trtorch::logging::log (trtorch::logging::Level::kERROR , " Invalid datatype for input specification " + spec);
340
- std::cerr << parser;
340
+ std::cerr << std::endl << parser;
341
341
exit (1 );
342
342
}
343
343
if (shapes.rfind (" (" , 0 ) == 0 ) {
@@ -347,7 +347,7 @@ int main(int argc, char** argv) {
347
347
ranges.push_back (trtorch::CompileSpec::Input (dyn_shapes[0 ], dyn_shapes[1 ], dyn_shapes[2 ], parsed_dtype));
348
348
} else {
349
349
trtorch::logging::log (trtorch::logging::Level::kERROR , spec_err_str);
350
- std::cerr << parser;
350
+ std::cerr << std::endl << parser;
351
351
exit (1 );
352
352
}
353
353
}
@@ -359,7 +359,7 @@ int main(int argc, char** argv) {
359
359
auto parsed_format = parseTensorFormat (format);
360
360
if (parsed_format == trtorch::CompileSpec::TensorFormat::kUnknown ) {
361
361
trtorch::logging::log (trtorch::logging::Level::kERROR , " Invalid format for input specification " + spec);
362
- std::cerr << parser;
362
+ std::cerr << std::endl << parser;
363
363
exit (1 );
364
364
}
365
365
if (shapes.rfind (" (" , 0 ) == 0 ) {
@@ -369,7 +369,7 @@ int main(int argc, char** argv) {
369
369
ranges.push_back (trtorch::CompileSpec::Input (dyn_shapes[0 ], dyn_shapes[1 ], dyn_shapes[2 ], parsed_format));
370
370
} else {
371
371
trtorch::logging::log (trtorch::logging::Level::kERROR , spec_err_str);
372
- std::cerr << parser;
372
+ std::cerr << std::endl << parser;
373
373
exit (1 );
374
374
}
375
375
// JUST SHAPE USE DEFAULT DTYPE
@@ -381,7 +381,7 @@ int main(int argc, char** argv) {
381
381
ranges.push_back (trtorch::CompileSpec::Input (dyn_shapes[0 ], dyn_shapes[1 ], dyn_shapes[2 ]));
382
382
} else {
383
383
trtorch::logging::log (trtorch::logging::Level::kERROR , spec_err_str);
384
- std::cerr << parser;
384
+ std::cerr << std::endl << parser;
385
385
exit (1 );
386
386
}
387
387
}
@@ -430,14 +430,15 @@ int main(int argc, char** argv) {
430
430
trtorch::logging::log (
431
431
trtorch::logging::Level::kERROR ,
432
432
" If targeting INT8 default operating precision with trtorchc, a calibration cache file must be provided" );
433
- std::cerr << parser;
434
433
return 1 ;
435
434
}
436
435
} else {
436
+ std::stringstream ss;
437
+ ss << " Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 ], found: " ;
438
+ ss << dtype;
437
439
trtorch::logging::log (
438
- trtorch::logging::Level::kERROR ,
439
- " Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 ]" );
440
- std::cerr << parser;
440
+ trtorch::logging::Level::kERROR , ss.str ());
441
+ std::cerr << std::endl << parser;
441
442
return 1 ;
442
443
}
443
444
}
@@ -460,8 +461,8 @@ int main(int argc, char** argv) {
460
461
compile_settings.device .dla_core = args::get (dla_core);
461
462
}
462
463
} else {
463
- trtorch::logging::log (trtorch::logging::Level::kERROR , " Invalid device type, options are [ gpu | dla ]" );
464
- std::cerr << parser;
464
+ trtorch::logging::log (trtorch::logging::Level::kERROR , " Invalid device type, options are [ gpu | dla ] found: " + device );
465
+ std::cerr << std::endl << parser;
465
466
return 1 ;
466
467
}
467
468
}
@@ -479,7 +480,7 @@ int main(int argc, char** argv) {
479
480
} else {
480
481
trtorch::logging::log (
481
482
trtorch::logging::Level::kERROR , " Invalid engine capability, options are [ default | safe_gpu | safe_dla ]" );
482
- std::cerr << parser;
483
+ std::cerr << std::endl << parser;
483
484
return 1 ;
484
485
}
485
486
}
@@ -517,7 +518,6 @@ int main(int argc, char** argv) {
517
518
mod = torch::jit::load (real_input_path);
518
519
} catch (const c10::Error& e) {
519
520
trtorch::logging::log (trtorch::logging::Level::kERROR , " Error loading the model (path may be incorrect)" );
520
- std::cerr << parser;
521
521
return 1 ;
522
522
}
523
523
0 commit comments