Skip to content

Commit 652fb13

Browse files
committed
feat(trtorchc): Adding more dtype aliases
Adding more aliases for f32 and f16 to make it easier on users Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent e1e7812 commit 652fb13

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

Diff for: cpp/trtorchc/main.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,9 @@ trtorch::CompileSpec::TensorFormat parseTensorFormat(std::string str) {
6060
trtorch::CompileSpec::DataType parseDataType(std::string dtype_str) {
6161
std::transform(
6262
dtype_str.begin(), dtype_str.end(), dtype_str.begin(), [](unsigned char c) { return std::tolower(c); });
63-
if (dtype_str == "float" || dtype_str == "float32" || dtype_str == "f32") {
63+
if (dtype_str == "float" || dtype_str == "float32" || dtype_str == "f32" || dtype_str == "fp32") {
6464
return trtorch::CompileSpec::DataType::kFloat;
65-
} else if (dtype_str == "half" || dtype_str == "float16" || dtype_str == "f16") {
65+
} else if (dtype_str == "half" || dtype_str == "float16" || dtype_str == "f16" || dtype_str == "fp16") {
6666
return trtorch::CompileSpec::DataType::kHalf;
6767
} else if (dtype_str == "char" || dtype_str == "int8" || dtype_str == "i8") {
6868
return trtorch::CompileSpec::DataType::kChar;
@@ -73,7 +73,7 @@ trtorch::CompileSpec::DataType parseDataType(std::string dtype_str) {
7373
} else {
7474
trtorch::logging::log(
7575
trtorch::logging::Level::kERROR,
76-
"Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 | int | int32 | i32 | bool | b], found: " + dtype_str);
76+
"Invalid precision, options are [ float | float32 | fp32 | f32 | half | float16 | fp16 | f16 | char | int8 | i8 | int | int32 | i32 | bool | b], found: " + dtype_str);
7777
return trtorch::CompileSpec::DataType::kUnknown;
7878
}
7979
}
@@ -214,7 +214,7 @@ int main(int argc, char** argv) {
214214
args::ValueFlagList<std::string> enabled_precision(
215215
parser,
216216
"precision",
217-
"(Repeatable) Enabling an operating precision for kernels to use when building the engine (Int8 requires a calibration-cache argument) [ float | float32 | f32 | half | float16 | f16 | int8 | i8 ] (default: float)",
217+
"(Repeatable) Enabling an operating precision for kernels to use when building the engine (Int8 requires a calibration-cache argument) [ float | float32 | f32 | fp32 | half | float16 | f16 | fp16 | int8 | i8 | char ] (default: float)",
218218
{'p', "enabled-precison"});
219219
args::ValueFlag<std::string> device_type(
220220
parser,
@@ -434,7 +434,7 @@ int main(int argc, char** argv) {
434434
}
435435
} else {
436436
std::stringstream ss;
437-
ss << "Invalid precision, options are [ float | float32 | f32 | half | float16 | f16 | char | int8 | i8 ], found: ";
437+
ss << "Invalid precision given for enabled kernel precision, options are [ float | float32 | f32 | fp32 | half | float16 | f16 | fp16 | char | int8 | i8 ], found: ";
438438
ss << dtype;
439439
trtorch::logging::log(
440440
trtorch::logging::Level::kERROR, ss.str());

0 commit comments

Comments
 (0)