Skip to content

Commit 14650d1

Browse files
committed
fix: Use user provided dtype when we can't infer it from the graph
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent e38056b commit 14650d1

File tree

4 files changed

+12
-6
lines changed

4 files changed

+12
-6
lines changed

Diff for: core/compiler.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,8 @@ void MapInputsAndDetermineDTypes(
328328
spec.dtype = nvinfer1::DataType::kFLOAT;
329329
} else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
330330
if (!est_type_opt) {
331-
LOG_INFO("Cannot infer input tensor dtype in graph, unable to verify user input dtype settings");
331+
LOG_INFO("Cannot infer input tensor dtype in graph. Using user provided input dtype settings");
332+
first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
332333
} else {
333334
if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
334335
std::stringstream ss;

Diff for: core/ir/Input.cpp

+9-2
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ bool valid_dtype_format_combo(nvinfer1::DataType dtype, nvinfer1::TensorFormat f
4040
default:
4141
return false;
4242
}
43+
case nvinfer1::DataType::kBOOL: // Supports Linear (NCHW)
44+
switch (format) {
45+
case nvinfer1::TensorFormat::kLINEAR:
46+
return true;
47+
default:
48+
return false;
49+
}
4350
default:
4451
return false;
4552
}
@@ -48,7 +55,7 @@ bool valid_dtype_format_combo(nvinfer1::DataType dtype, nvinfer1::TensorFormat f
4855
bool valid_input_dtype(nvinfer1::DataType dtype) {
4956
switch (dtype) {
5057
case nvinfer1::DataType::kBOOL:
51-
return false;
58+
return true;
5259
case nvinfer1::DataType::kFLOAT:
5360
return true;
5461
case nvinfer1::DataType::kHALF:
@@ -153,4 +160,4 @@ std::ostream& operator<<(std::ostream& os, const Input& input) {
153160

154161
} // namespace ir
155162
} // namespace core
156-
} // namespace torch_tensorrt
163+
} // namespace torch_tensorrt

Diff for: py/setup.py

-2
Original file line numberDiff line numberDiff line change
@@ -239,8 +239,6 @@ def run(self):
239239
libraries=["torchtrt"],
240240
include_dirs=[
241241
dir_path + "torch_tensorrt/csrc", dir_path + "torch_tensorrt/include",
242-
dir_path + "/../bazel-TRTorch/external/tensorrt/include",
243-
dir_path + "/../bazel-Torch-TensorRT-Preview/external/tensorrt/include",
244242
dir_path + "/../bazel-Torch-TensorRT/external/tensorrt/include", dir_path + "/../"
245243
],
246244
extra_compile_args=[

Diff for: py/torch_tensorrt/csrc/torch_tensorrt_py.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ PYBIND11_MODULE(_C, m) {
185185
.value("float16", DataType::kHalf, "16 bit floating point number")
186186
.value("int8", DataType::kChar, "8 bit integer number")
187187
.value("int32", DataType::kInt32, "32 bit integer number")
188-
.value("bool", DataType::kChar, "Boolean value")
188+
.value("bool", DataType::kBool, "Boolean value")
189189
.value("unknown", DataType::kUnknown, "Unknown data type")
190190
.export_values();
191191

0 commit comments

Comments
 (0)