Skip to content

Commit e1e7812

Browse files
committed
fix(trtorchc): Allow for workspaces larger than 2G and better debugging
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 24de61b commit e1e7812

File tree

1 file changed

+26
-26
lines changed

1 file changed

+26
-26
lines changed

Diff for: cpp/trtorchc/main.cpp

+26-26
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ trtorch::CompileSpec::TensorFormat parseTensorFormat(std::string str) {
5252
} else {
5353
trtorch::logging::log(
5454
trtorch::logging::Level::kERROR,
55-
"Invalid tensor format, options are [ linear | nchw | chw | contiguous | nhwc | hwc | channels_last ]");
55+
"Invalid tensor format, options are [ linear | nchw | chw | contiguous | nhwc | hwc | channels_last ], found: " + str);
5656
return trtorch::CompileSpec::TensorFormat::kUnknown;
5757
}
5858
}
@@ -73,7 +73,7 @@ trtorch::CompileSpec::DataType parseDataType(std::string dtype_str) {
7373
} else {
7474
trtorch::logging::log(
7575
trtorch::logging::Level::kERROR,
76-
"Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 | int | int32 | i32 | bool | b]");
76+
"Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 | int | int32 | i32 | bool | b], found: " + dtype_str);
7777
return trtorch::CompileSpec::DataType::kUnknown;
7878
}
7979
}
@@ -221,8 +221,8 @@ int main(int argc, char** argv) {
221221
"type",
222222
"The type of device the engine should be built for [ gpu | dla ] (default: gpu)",
223223
{'d', "device-type"});
224-
args::ValueFlag<int> gpu_id(parser, "gpu_id", "GPU id if running on multi-GPU platform (defaults to 0)", {"gpu-id"});
225-
args::ValueFlag<int> dla_core(
224+
args::ValueFlag<uint64_t> gpu_id(parser, "gpu_id", "GPU id if running on multi-GPU platform (defaults to 0)", {"gpu-id"});
225+
args::ValueFlag<uint64_t> dla_core(
226226
parser, "dla_core", "DLACore id if running on available DLA (defaults to 0)", {"dla-core"});
227227

228228
args::ValueFlag<std::string> engine_capability(
@@ -243,13 +243,13 @@ int main(int argc, char** argv) {
243243
"Whether to treat input file as a serialized TensorRT engine and embed it into a TorchScript module (device spec must be provided)",
244244
{"embed-engine"});
245245

246-
args::ValueFlag<int> num_min_timing_iters(
246+
args::ValueFlag<uint64_t> num_min_timing_iters(
247247
parser, "num_iters", "Number of minimization timing iterations used to select kernels", {"num-min-timing-iter"});
248-
args::ValueFlag<int> num_avg_timing_iters(
248+
args::ValueFlag<uint64_t> num_avg_timing_iters(
249249
parser, "num_iters", "Number of averaging timing iterations used to select kernels", {"num-avg-timing-iters"});
250-
args::ValueFlag<int> workspace_size(
250+
args::ValueFlag<uint64_t> workspace_size(
251251
parser, "workspace_size", "Maximum size of workspace given to TensorRT", {"workspace-size"});
252-
args::ValueFlag<int> max_batch_size(
252+
args::ValueFlag<uint64_t> max_batch_size(
253253
parser, "max_batch_size", "Maximum batch size (must be >= 1 to be set, 0 means not set)", {"max-batch-size"});
254254
args::ValueFlag<double> threshold(
255255
parser,
@@ -276,8 +276,8 @@ int main(int argc, char** argv) {
276276
std::cout << parser;
277277
return 0;
278278
} catch (args::ParseError e) {
279-
std::cerr << e.what() << std::endl;
280-
std::cerr << parser;
279+
trtorch::logging::log(trtorch::logging::Level::kERROR, e.what());
280+
std::cerr << std::endl << parser;
281281
return 1;
282282
}
283283

@@ -309,13 +309,13 @@ int main(int argc, char** argv) {
309309
auto parsed_dtype = parseDataType(dtype);
310310
if (parsed_dtype == trtorch::CompileSpec::DataType::kUnknown) {
311311
trtorch::logging::log(trtorch::logging::Level::kERROR, "Invalid datatype for input specification " + spec);
312-
std::cerr << parser;
312+
std::cerr << std::endl << parser;
313313
exit(1);
314314
}
315315
auto parsed_format = parseTensorFormat(format);
316316
if (parsed_format == trtorch::CompileSpec::TensorFormat::kUnknown) {
317317
trtorch::logging::log(trtorch::logging::Level::kERROR, "Invalid format for input specification " + spec);
318-
std::cerr << parser;
318+
std::cerr << std::endl << parser;
319319
exit(1);
320320
}
321321
if (shapes.rfind("(", 0) == 0) {
@@ -326,7 +326,7 @@ int main(int argc, char** argv) {
326326
trtorch::CompileSpec::Input(dyn_shapes[0], dyn_shapes[1], dyn_shapes[2], parsed_dtype, parsed_format));
327327
} else {
328328
trtorch::logging::log(trtorch::logging::Level::kERROR, spec_err_str);
329-
std::cerr << parser;
329+
std::cerr << std::endl << parser;
330330
exit(1);
331331
}
332332
// THERE IS NO SPEC FOR FORMAT
@@ -337,7 +337,7 @@ int main(int argc, char** argv) {
337337
auto parsed_dtype = parseDataType(dtype);
338338
if (parsed_dtype == trtorch::CompileSpec::DataType::kUnknown) {
339339
trtorch::logging::log(trtorch::logging::Level::kERROR, "Invalid datatype for input specification " + spec);
340-
std::cerr << parser;
340+
std::cerr << std::endl << parser;
341341
exit(1);
342342
}
343343
if (shapes.rfind("(", 0) == 0) {
@@ -347,7 +347,7 @@ int main(int argc, char** argv) {
347347
ranges.push_back(trtorch::CompileSpec::Input(dyn_shapes[0], dyn_shapes[1], dyn_shapes[2], parsed_dtype));
348348
} else {
349349
trtorch::logging::log(trtorch::logging::Level::kERROR, spec_err_str);
350-
std::cerr << parser;
350+
std::cerr << std::endl << parser;
351351
exit(1);
352352
}
353353
}
@@ -359,7 +359,7 @@ int main(int argc, char** argv) {
359359
auto parsed_format = parseTensorFormat(format);
360360
if (parsed_format == trtorch::CompileSpec::TensorFormat::kUnknown) {
361361
trtorch::logging::log(trtorch::logging::Level::kERROR, "Invalid format for input specification " + spec);
362-
std::cerr << parser;
362+
std::cerr << std::endl << parser;
363363
exit(1);
364364
}
365365
if (shapes.rfind("(", 0) == 0) {
@@ -369,7 +369,7 @@ int main(int argc, char** argv) {
369369
ranges.push_back(trtorch::CompileSpec::Input(dyn_shapes[0], dyn_shapes[1], dyn_shapes[2], parsed_format));
370370
} else {
371371
trtorch::logging::log(trtorch::logging::Level::kERROR, spec_err_str);
372-
std::cerr << parser;
372+
std::cerr << std::endl << parser;
373373
exit(1);
374374
}
375375
// JUST SHAPE USE DEFAULT DTYPE
@@ -381,7 +381,7 @@ int main(int argc, char** argv) {
381381
ranges.push_back(trtorch::CompileSpec::Input(dyn_shapes[0], dyn_shapes[1], dyn_shapes[2]));
382382
} else {
383383
trtorch::logging::log(trtorch::logging::Level::kERROR, spec_err_str);
384-
std::cerr << parser;
384+
std::cerr << std::endl << parser;
385385
exit(1);
386386
}
387387
}
@@ -430,14 +430,15 @@ int main(int argc, char** argv) {
430430
trtorch::logging::log(
431431
trtorch::logging::Level::kERROR,
432432
"If targeting INT8 default operating precision with trtorchc, a calibration cache file must be provided");
433-
std::cerr << parser;
434433
return 1;
435434
}
436435
} else {
436+
std::stringstream ss;
437+
ss << "Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 ], found: ";
438+
ss << dtype;
437439
trtorch::logging::log(
438-
trtorch::logging::Level::kERROR,
439-
"Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 ]");
440-
std::cerr << parser;
440+
trtorch::logging::Level::kERROR, ss.str());
441+
std::cerr << std::endl << parser;
441442
return 1;
442443
}
443444
}
@@ -460,8 +461,8 @@ int main(int argc, char** argv) {
460461
compile_settings.device.dla_core = args::get(dla_core);
461462
}
462463
} else {
463-
trtorch::logging::log(trtorch::logging::Level::kERROR, "Invalid device type, options are [ gpu | dla ]");
464-
std::cerr << parser;
464+
trtorch::logging::log(trtorch::logging::Level::kERROR, "Invalid device type, options are [ gpu | dla ] found: " + device);
465+
std::cerr << std::endl << parser;
465466
return 1;
466467
}
467468
}
@@ -479,7 +480,7 @@ int main(int argc, char** argv) {
479480
} else {
480481
trtorch::logging::log(
481482
trtorch::logging::Level::kERROR, "Invalid engine capability, options are [ default | safe_gpu | safe_dla ]");
482-
std::cerr << parser;
483+
std::cerr << std::endl << parser;
483484
return 1;
484485
}
485486
}
@@ -517,7 +518,6 @@ int main(int argc, char** argv) {
517518
mod = torch::jit::load(real_input_path);
518519
} catch (const c10::Error& e) {
519520
trtorch::logging::log(trtorch::logging::Level::kERROR, "Error loading the model (path may be incorrect)");
520-
std::cerr << parser;
521521
return 1;
522522
}
523523

0 commit comments

Comments
 (0)