Skip to content

Commit 2b69742

Browse files
committed
fix: Fix TRT8 engine capability flags
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent e336630 commit 2b69742

File tree

3 files changed

+13
-13
lines changed

3 files changed

+13
-13
lines changed

Diff for: cpp/trtorchc/README.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,8 @@ trtorchc [input_file_path] [output_file_path]
5959
--dla-core=[dla_core] DLACore id if running on available DLA
6060
(defaults to 0)
6161
--engine-capability=[capability] The type of device the engine should be
62-
built for [ default | safe_gpu |
63-
safe_dla ]
62+
built for [ standard | safety |
63+
dla_standalone ]
6464
--calibration-cache-file=[file_path]
6565
Path to calibration cache file to use
6666
for post training quantization

Diff for: cpp/trtorchc/main.cpp

+7-7
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ int main(int argc, char** argv) {
264264
args::ValueFlag<std::string> engine_capability(
265265
parser,
266266
"capability",
267-
"The type of device the engine should be built for [ default | safe_gpu | safe_dla ]",
267+
"The type of device the engine should be built for [ standard | safety | dla_standalone ]",
268268
{"engine-capability"});
269269

270270
args::ValueFlag<std::string> calibration_cache_file(
@@ -537,12 +537,12 @@ int main(int argc, char** argv) {
537537
auto capability = args::get(engine_capability);
538538
std::transform(
539539
capability.begin(), capability.end(), capability.begin(), [](unsigned char c) { return std::tolower(c); });
540-
if (capability == "default") {
541-
compile_settings.capability = trtorch::CompileSpec::EngineCapability::kDEFAULT;
542-
} else if (capability == "safe_gpu") {
543-
compile_settings.capability = trtorch::CompileSpec::EngineCapability::kSAFE_GPU;
544-
} else if (capability == "safe_dla") {
545-
compile_settings.capability = trtorch::CompileSpec::EngineCapability::kSAFE_DLA;
540+
if (capability == "standard") {
541+
compile_settings.capability = trtorch::CompileSpec::EngineCapability::kSTANDARD;
542+
} else if (capability == "safety") {
543+
compile_settings.capability = trtorch::CompileSpec::EngineCapability::kSAFETY;
544+
} else if (capability == "dla_standalone") {
545+
compile_settings.capability = trtorch::CompileSpec::EngineCapability::kDLA_STANDALONE;
546546
} else {
547547
trtorch::logging::log(
548548
trtorch::logging::Level::kERROR, "Invalid engine capability, options are [ default | safe_gpu | safe_dla ]");

Diff for: docsrc/tutorials/trtorchc.rst

+4-4
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would r
4545
--ffo,
4646
--forced-fallback-ops List of operators in the graph that
4747
should be forced to fallback to Pytorch for execution
48-
48+
4949
--disable-tf32 Prevent Float32 layers from using the
5050
TF32 data format
5151
-p[precision...],
@@ -55,16 +55,16 @@ to standard TorchScript. Load with ``torch.jit.load()`` and run like you would r
5555
calibration-cache argument) [ float |
5656
float32 | f32 | half | float16 | f16 |
5757
int8 | i8 ] (default: float)
58-
58+
5959
-d[type], --device-type=[type] The type of device the engine should be
6060
built for [ gpu | dla ] (default: gpu)
6161
--gpu-id=[gpu_id] GPU id if running on multi-GPU platform
6262
(defaults to 0)
6363
--dla-core=[dla_core] DLACore id if running on available DLA
6464
(defaults to 0)
6565
--engine-capability=[capability] The type of device the engine should be
66-
built for [ default | safe_gpu |
67-
safe_dla ]
66+
built for [ standard | safety |
67+
dla_standalone ]
6868
--calibration-cache-file=[file_path]
6969
Path to calibration cache file to use
7070
for post training quantization

0 commit comments

Comments
 (0)