Skip to content

Commit a3f4a3c

Browse files
peri044 and narendasan
authored and committed
feat: Add support for providing input datatypes in TRTorch
Signed-off-by: Dheeraj Peri <[email protected]> Signed-off-by: Dheeraj Peri <[email protected]>
1 parent bdaacf1 commit a3f4a3c

File tree

12 files changed

+125
-28
lines changed

12 files changed

+125
-28
lines changed

Diff for: README.md

+7-2
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,11 @@ More Information / System Architecture:
2020
...
2121
auto compile_settings = trtorch::CompileSpec(dims);
2222
// FP16 execution
23-
compile_settings.op_precision = torch::kFloat;
23+
compile_settings.op_precision = torch::kHalf;
24+
// Set input datatypes. Allowed options torch::{kFloat, kHalf, kChar, kInt32, kBool}
25+
// Size of input_dtypes should match number of inputs to the network.
26+
// If input_dtypes is not set, default precision for input tensors would be float32
27+
compile_settings.input_dtypes = {torch::kHalf};
2428
// Compile module
2529
auto trt_mod = trtorch::CompileGraph(ts_mod, compile_settings);
2630
// Run like normal
@@ -43,7 +47,8 @@ compile_settings = {
4347
"max": [1, 3, 1024, 1024]
4448
}, # For static size [1, 3, 224, 224]
4549
],
46-
"op_precision": torch.half # Run with FP16
50+
"op_precision": torch.half, # Run with FP16
51+
"input_dtypes": [torch.half] # Datatype of input tensor. Allowed options torch.(float|half|int8|int32|bool)
4752
}
4853
4954
trt_ts_module = trtorch.compile(torch_script_module, compile_settings)

Diff for: core/conversion/conversion.cpp

+16-3
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,27 @@ void AddInputs(
150150

151151
auto profile = ctx->builder->createOptimizationProfile();
152152

153+
TRTORCH_CHECK(
154+
ctx->input_dtypes.size() == 0 || ctx->input_dtypes.size() == input_tensors.size(),
155+
"Number of input_dtypes : " << ctx->input_dtypes.size()
156+
<< " should either be 0 or equal to number of input_tensors which is "
157+
<< input_tensors.size() << " (conversion.AddInputs)");
158+
159+
// If the input_dtypes is not provided, assume all the input tensors to be in float32
160+
if (ctx->input_dtypes.size() == 0) {
161+
LOG_DEBUG("Input datatypes are not provided explicitly. Default float32 datatype is being used for all inputs");
162+
ctx->input_dtypes = std::vector<nvinfer1::DataType>(input_tensors.size(), nvinfer1::DataType::kFLOAT);
163+
}
164+
153165
for (size_t i = 0; i < input_tensors.size(); i++) {
154166
auto in = input_tensors[i];
155167
auto dims = input_dims[i];
156168
std::string name = std::string("input_") + std::to_string(ctx->num_inputs);
157169
LOG_INFO(
158-
ctx->logger, "Adding Input " << in->debugName() << " named " << name << " in engine (conversion.AddInputs)");
159-
LOG_DEBUG(ctx->logger, "Input shape set to " << dims.input_shape);
160-
auto trt_in = ctx->net->addInput(name.c_str(), ctx->input_type, dims.input_shape);
170+
ctx->logger,
171+
"Adding Input " << in->debugName() << " named : " << name << ", shape: " << dims.input_shape
172+
<< ", dtype : " << ctx->input_dtypes[i] << " in engine (conversion.AddInputs)");
173+
auto trt_in = ctx->net->addInput(name.c_str(), ctx->input_dtypes[i], dims.input_shape);
161174
TRTORCH_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)");
162175

163176
profile->setDimensions(trt_in->getName(), nvinfer1::OptProfileSelector::kMIN, dims.min);

Diff for: core/conversion/conversionctx/ConversionCtx.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -61,26 +61,27 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
6161
case nvinfer1::DataType::kHALF:
6262
TRTORCH_CHECK(builder->platformHasFastFp16(), "Requested inference in FP16 but platform does not support FP16");
6363
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
64-
input_type = nvinfer1::DataType::kHALF;
6564
break;
6665
case nvinfer1::DataType::kINT8:
6766
TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does not support INT8");
6867
cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
6968
if (!settings.strict_types) {
7069
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
7170
}
72-
input_type = nvinfer1::DataType::kFLOAT;
7371
TRTORCH_CHECK(
7472
settings.calibrator != nullptr,
7573
"Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec struct with your calibrator");
7674
cfg->setInt8Calibrator(settings.calibrator);
7775
break;
7876
case nvinfer1::DataType::kFLOAT:
77+
case nvinfer1::DataType::kINT32:
78+
case nvinfer1::DataType::kBOOL:
7979
default:
80-
input_type = nvinfer1::DataType::kFLOAT;
8180
break;
8281
}
82+
8383
op_precision = settings.op_precision;
84+
input_dtypes = settings.input_dtypes;
8485

8586
if (settings.disable_tf32) {
8687
cfg->clearFlag(nvinfer1::BuilderFlag::kTF32);

Diff for: core/conversion/conversionctx/ConversionCtx.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ struct Device {
2424

2525
struct BuilderSettings {
2626
nvinfer1::DataType op_precision = nvinfer1::DataType::kFLOAT;
27+
std::vector<nvinfer1::DataType> input_dtypes;
2728
bool disable_tf32 = false;
2829
bool refit = false;
2930
bool debug = false;
@@ -57,7 +58,7 @@ struct ConversionCtx {
5758
nvinfer1::IBuilder* builder;
5859
nvinfer1::INetworkDefinition* net;
5960
nvinfer1::IBuilderConfig* cfg;
60-
nvinfer1::DataType input_type;
61+
std::vector<nvinfer1::DataType> input_dtypes;
6162
nvinfer1::DataType op_precision;
6263
BuilderSettings settings;
6364
util::logging::TRTorchLogger logger;

Diff for: cpp/api/include/trtorch/trtorch.h

+9
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ struct TRTORCH_API CompileSpec {
115115
kHalf,
116116
/// INT8
117117
kChar,
118+
/// INT32
119+
kInt32,
120+
/// Bool
121+
kBool,
118122
};
119123

120124
/**
@@ -239,6 +243,11 @@ struct TRTORCH_API CompileSpec {
239243
*/
240244
DataType op_precision = DataType::kFloat;
241245

246+
/**
247+
* Data types for input tensors
248+
*/
249+
std::vector<DataType> input_dtypes;
250+
242251
/**
243252
* Prevent Float32 layers from using TF32 data format
244253
*

Diff for: cpp/api/src/compile_spec.cpp

+30-12
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,26 @@
77

88
namespace trtorch {
99
CompileSpec::DataType::DataType(c10::ScalarType t) {
10-
TRTORCH_CHECK(t == at::kHalf || t == at::kFloat || t == at::kChar, "Data type is unsupported");
10+
TRTORCH_CHECK(
11+
t == at::kHalf || t == at::kFloat || t == at::kChar || t == at::kInt || t == at::kBool,
12+
"Data type is unsupported");
1113
switch (t) {
1214
case at::kHalf:
1315
value = DataType::kHalf;
1416
break;
17+
case at::kChar:
18+
value = DataType::kChar;
19+
break;
20+
case at::kInt:
21+
value = DataType::kInt32;
22+
break;
23+
case at::kBool:
24+
value = DataType::kBool;
25+
break;
1526
case at::kFloat:
1627
default:
1728
value = DataType::kFloat;
1829
break;
19-
case at::kChar:
20-
value = DataType::kChar;
2130
}
2231
}
2332

@@ -74,19 +83,28 @@ std::vector<core::ir::InputRange> to_vec_internal_input_ranges(std::vector<Compi
7483
return internal;
7584
}
7685

77-
core::CompileSpec to_internal_compile_spec(CompileSpec external) {
78-
core::CompileSpec internal(to_vec_internal_input_ranges(external.input_ranges));
79-
80-
switch (external.op_precision) {
86+
nvinfer1::DataType toTRTDataType(CompileSpec::DataType value) {
87+
switch (value) {
8188
case CompileSpec::DataType::kChar:
82-
internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kINT8;
83-
break;
89+
return nvinfer1::DataType::kINT8;
8490
case CompileSpec::DataType::kHalf:
85-
internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kHALF;
86-
break;
91+
return nvinfer1::DataType::kHALF;
92+
case CompileSpec::DataType::kInt32:
93+
return nvinfer1::DataType::kINT32;
94+
case CompileSpec::DataType::kBool:
95+
return nvinfer1::DataType::kBOOL;
8796
case CompileSpec::DataType::kFloat:
8897
default:
89-
internal.convert_info.engine_settings.op_precision = nvinfer1::DataType::kFLOAT;
98+
return nvinfer1::DataType::kFLOAT;
99+
}
100+
}
101+
102+
core::CompileSpec to_internal_compile_spec(CompileSpec external) {
103+
core::CompileSpec internal(to_vec_internal_input_ranges(external.input_ranges));
104+
105+
internal.convert_info.engine_settings.op_precision = toTRTDataType(external.op_precision);
106+
for (auto dtype : external.input_dtypes) {
107+
internal.convert_info.engine_settings.input_dtypes.push_back(toTRTDataType(dtype));
90108
}
91109

92110
internal.convert_info.engine_settings.disable_tf32 = external.disable_tf32;

Diff for: py/trtorch/_compile_spec.py

+25
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,24 @@ def _parse_torch_fallback(fallback_info: Dict[str, Any]) -> trtorch._C.TorchFall
140140

141141
return info
142142

143+
def _parse_input_dtypes(input_dtypes: List) -> List:
144+
parsed_input_dtypes = []
145+
for dtype in input_dtypes:
146+
if isinstance(dtype, torch.dtype):
147+
if dtype == torch.int8:
148+
parsed_input_dtypes.append(_types.dtype.int8)
149+
elif dtype == torch.half:
150+
parsed_input_dtypes.append(_types.dtype.half)
151+
elif dtype == torch.float:
152+
parsed_input_dtypes.append(_types.dtype.float)
153+
elif dtype == torch.int32:
154+
parsed_input_dtypes.append(_types.dtype.int32)
155+
elif dtype == torch.bool:
156+
parsed_input_dtypes.append(_types.dtype.bool)
157+
else:
158+
raise TypeError("Invalid input dtype. Supported input datatypes include (float|half|int8|int32|bool), got: " + str(dtype))
159+
160+
return parsed_input_dtypes
143161

144162
def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
145163
info = trtorch._C.CompileSpec()
@@ -153,6 +171,9 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
153171
if "op_precision" in compile_spec:
154172
info.op_precision = _parse_op_precision(compile_spec["op_precision"])
155173

174+
if "input_dtypes" in compile_spec:
175+
info.input_dtypes = _parse_input_dtypes(compile_spec["input_dtypes"])
176+
156177
if "calibrator" in compile_spec:
157178
info.ptq_calibrator = compile_spec["calibrator"]
158179

@@ -237,6 +258,7 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
237258
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
238259
},
239260
"op_precision": torch.half, # Operating precision set to FP16
261+
# List of datatypes that should be configured for each input. Supported options torch.{float|half|int8|int32|bool}.
240262
"disable_tf32": False, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
241263
"refit": False, # enable refit
242264
"debug": False, # enable debuggable engine
@@ -288,6 +310,9 @@ def TensorRTCompileSpec(compile_spec: Dict[str, Any]) -> torch.classes.tensorrt.
288310
backend_spec._set_device(d)
289311
backend_spec._set_torch_fallback(torch_fallback)
290312
backend_spec._set_op_precision(int(parsed_spec.op_precision))
313+
for dtype in parsed_spec.input_dtypes:
314+
backend_spec._append_input_dtypes(int(dtype))
315+
291316
backend_spec._set_disable_tf32(parsed_spec.disable_tf32)
292317
backend_spec._set_refit(parsed_spec.refit)
293318
backend_spec._set_debug(parsed_spec.debug)

Diff for: py/trtorch/_compiler.py

+2
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ def compile(module: torch.jit.ScriptModule, compile_spec: Any) -> torch.jit.Scri
4141
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
4242
},
4343
"op_precision": torch.half, # Operating precision set to FP16
44+
"input_dtypes": [torch.float32] # List of datatypes that should be configured for each input. Supported options torch.{float|half|int8|int32|bool}.
4445
"refit": false, # enable refit
4546
"debug": false, # enable debuggable engine
4647
"strict_types": false, # kernels should strictly run in operating precision
@@ -106,6 +107,7 @@ def convert_method_to_trt_engine(module: torch.jit.ScriptModule, method_name: st
106107
"allow_gpu_fallback": false, # (DLA only) Allow layers unsupported on DLA to run on GPU
107108
},
108109
"op_precision": torch.half, # Operating precision set to FP16
110+
# List of datatypes that should be configured for each input. Supported options torch.{float|half|int8|int32|bool}.
109111
"disable_tf32": False, # Force FP32 layers to use traditional as FP32 format vs the default behavior of rounding the inputs to 10-bit mantissas before multiplying, but accumulates the sum using 23-bit mantissas
110112
"refit": false, # enable refit
111113
"debug": false, # enable debuggable engine

Diff for: py/trtorch/csrc/register_tensorrt_classes.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ void RegisterTRTCompileSpec() {
4343
.def("_set_device", &trtorch::pyapi::CompileSpec::setDeviceIntrusive)
4444
.def("_set_torch_fallback", &trtorch::pyapi::CompileSpec::setTorchFallbackIntrusive)
4545
.def("_set_ptq_calibrator", &trtorch::pyapi::CompileSpec::setPTQCalibratorViaHandle)
46+
.def("_append_input_dtypes", &trtorch::pyapi::CompileSpec::appendInputDtypes)
4647
.def("__str__", &trtorch::pyapi::CompileSpec::stringify);
4748

4849
ADD_FIELD_GET_SET_REGISTRATION(TRTCompileSpecTSRegistration, trtorch::pyapi::CompileSpec, op_precision);

Diff for: py/trtorch/csrc/tensorrt_classes.cpp

+20-1
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ std::string to_str(DataType value) {
3030
return "Half";
3131
case DataType::kChar:
3232
return "Int8";
33+
case DataType::kInt32:
34+
return "Int32";
35+
case DataType::kBool:
36+
return "Bool";
3337
case DataType::kFloat:
3438
default:
3539
return "Float";
@@ -42,6 +46,10 @@ nvinfer1::DataType toTRTDataType(DataType value) {
4246
return nvinfer1::DataType::kINT8;
4347
case DataType::kHalf:
4448
return nvinfer1::DataType::kHALF;
49+
case DataType::kInt32:
50+
return nvinfer1::DataType::kINT32;
51+
case DataType::kBool:
52+
return nvinfer1::DataType::kBOOL;
4553
case DataType::kFloat:
4654
default:
4755
return nvinfer1::DataType::kFLOAT;
@@ -124,8 +132,15 @@ core::CompileSpec CompileSpec::toInternalCompileSpec() {
124132
for (auto i : input_ranges) {
125133
internal_input_ranges.push_back(i.toInternalInputRange());
126134
}
135+
136+
std::vector<nvinfer1::DataType> trt_input_dtypes;
137+
for (auto dtype : input_dtypes) {
138+
trt_input_dtypes.push_back(toTRTDataType(dtype));
139+
}
140+
127141
auto info = core::CompileSpec(internal_input_ranges);
128142
info.convert_info.engine_settings.op_precision = toTRTDataType(op_precision);
143+
info.convert_info.engine_settings.input_dtypes = trt_input_dtypes;
129144
info.convert_info.engine_settings.calibrator = ptq_calibrator;
130145
info.convert_info.engine_settings.disable_tf32 = disable_tf32;
131146
info.convert_info.engine_settings.refit = refit;
@@ -159,9 +174,13 @@ std::string CompileSpec::stringify() {
159174
for (auto i : input_ranges) {
160175
ss << i.to_str();
161176
}
162-
std::string enabled = torch_fallback.enabled ? "True" : "False";
163177
ss << " ]" << std::endl;
164178
ss << " \"Op Precision\": " << to_str(op_precision) << std::endl;
179+
ss << " \"Input dtypes\": [" << std::endl;
180+
for (auto i : input_dtypes) {
181+
ss << to_str(i);
182+
}
183+
ss << " ]" << std::endl;
165184
ss << " \"TF32 Disabled\": " << disable_tf32 << std::endl;
166185
ss << " \"Refit\": " << refit << std::endl;
167186
ss << " \"Debug\": " << debug << std::endl;

Diff for: py/trtorch/csrc/tensorrt_classes.h

+6-6
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,7 @@ struct InputRange : torch::CustomClassHolder {
4343
std::string to_str();
4444
};
4545

46-
enum class DataType : int8_t {
47-
kFloat,
48-
kHalf,
49-
kChar,
50-
};
46+
enum class DataType : int8_t { kFloat, kHalf, kChar, kInt32, kBool };
5147

5248
std::string to_str(DataType value);
5349
nvinfer1::DataType toTRTDataType(DataType value);
@@ -108,7 +104,9 @@ struct CompileSpec : torch::CustomClassHolder {
108104
void appendInputRange(const c10::intrusive_ptr<InputRange>& ir) {
109105
input_ranges.push_back(*ir);
110106
}
111-
107+
void appendInputDtypes(int64_t dtype) {
108+
input_dtypes.push_back(static_cast<DataType>(dtype));
109+
}
112110
int64_t getPTQCalibratorHandle() {
113111
return (int64_t)ptq_calibrator;
114112
}
@@ -120,6 +118,7 @@ struct CompileSpec : torch::CustomClassHolder {
120118
void setTorchFallbackIntrusive(const c10::intrusive_ptr<TorchFallback>& fb) {
121119
torch_fallback = *fb;
122120
}
121+
123122
void setPTQCalibratorViaHandle(int64_t handle) {
124123
ptq_calibrator = (nvinfer1::IInt8Calibrator*)handle;
125124
}
@@ -142,6 +141,7 @@ struct CompileSpec : torch::CustomClassHolder {
142141
std::vector<InputRange> input_ranges;
143142
nvinfer1::IInt8Calibrator* ptq_calibrator = nullptr;
144143
DataType op_precision = DataType::kFloat;
144+
std::vector<DataType> input_dtypes;
145145
bool disable_tf32 = false;
146146
bool refit = false;
147147
bool debug = false;

Diff for: py/trtorch/csrc/trtorch_py.cpp

+3
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ PYBIND11_MODULE(_C, m) {
176176
.value("half", DataType::kHalf, "16 bit floating point number")
177177
.value("float16", DataType::kHalf, "16 bit floating point number")
178178
.value("int8", DataType::kChar, "8 bit integer number")
179+
.value("int32", DataType::kInt32, "32 bit integer number")
180+
.value("bool", DataType::kBool, "Boolean value")
179181
.export_values();
180182

181183
py::enum_<DeviceType>(m, "DeviceType", "Enum to specify device kinds to build TensorRT engines for")
@@ -242,6 +244,7 @@ PYBIND11_MODULE(_C, m) {
242244
.def("_get_calibrator_handle", &CompileSpec::getPTQCalibratorHandle, "[Internal] gets a handle from a calibrator")
243245
.def_readwrite("input_ranges", &CompileSpec::input_ranges)
244246
.def_readwrite("op_precision", &CompileSpec::op_precision)
247+
.def_readwrite("input_dtypes", &CompileSpec::input_dtypes)
245248
.def_readwrite("ptq_calibrator", &CompileSpec::ptq_calibrator)
246249
.def_readwrite("refit", &CompileSpec::refit)
247250
.def_readwrite("disable_tf32", &CompileSpec::disable_tf32)

0 commit comments

Comments
 (0)