
Commit 14ed6dd

feat: Add option to specify int64 as an Input dtype

- Rework the `Input` paradigm to be based on `at::ScalarType` rather than the previous `nvinfer1::DataType`, allowing a larger representation space of data types
- When paired with `truncate_long_and_double`, insert casts to ensure Torch engines using Int64 tensors receive the correct types, and TensorRT engines operating on those tensors receive downcasted Int32 versions thereof
- Add a Torch block at the beginning of the model graph to prepare the types of input tensors for the forthcoming engines in the sequence
- Automatically follow internal tensor types to abstract away the different internal engines used (Torch/TensorRT) from the user
- Provide a framework for streamlined addition of other data types, including `torch.double`, as valid input types
- Improve error checking to ensure model compilation and behavior is as documented; for example, disallow specification of a Long-typed input if the engine is required to be converted entirely to TRT

Known limitations:
- Specifying `dtype=torch.long` on an `Input` in an `input_signature` is not currently supported and will throw an error before model compilation when used with the Python API
- While Torch may output Int64 tensors from the overall model, Torch-TRT can currently only output Int32 tensors for models using TRT, as there is no mechanism in place for differentiating intermediate blocks from final/beginning blocks in the graph
- Torch-TRT will almost certainly alter the data type of the input tensor in-place if `dtype=torch.long` is specified, and the returned result will be of type `torch.int32`
1 parent f43be5b commit 14ed6dd
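
A minimal usage sketch of how the new option is meant to be exercised (not part of the commit; the model path, shapes, and device are illustrative, and the API names are assumed from the public torch_tensorrt C++ interface):

```cpp
#include <torch/script.h>
#include "torch_tensorrt/torch_tensorrt.h"

int main() {
  // Hypothetical TorchScript module that consumes an Int64 tensor (e.g. token ids).
  auto mod = torch::jit::load("model_with_long_input.ts");

  // Declare the input as Long. With truncate_long_and_double enabled, the
  // compiler inserts the casts described above rather than rejecting the type;
  // without it (or with require_full_compilation), a Long input is an error.
  auto spec = torch_tensorrt::ts::CompileSpec(
      {torch_tensorrt::Input({1, 128}, torch_tensorrt::DataType::kLong)});
  spec.truncate_long_and_double = true;

  auto trt_mod = torch_tensorrt::ts::compile(mod, spec);

  // Per the known limitations: the Long input may be modified in-place and
  // the returned tensor will be Int32 for TRT-backed models.
  auto in = torch::ones({1, 128}, torch::kLong).cuda();
  auto out = trt_mod.forward({in});
  return 0;
}
```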

File tree: 17 files changed (+234, -53)

core/compiler.cpp (+53, -31)

```diff
@@ -187,7 +187,7 @@ partitioning::GraphAndMapping BuildHybridGraph(
   return partitioning::stitch(&partitioning_ctx, block);
 }
 
-void MapInputsAndDetermineDTypes(
+ir::TypeMap MapInputsAndDetermineDTypes(
     CompileSpec& cfg,
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
@@ -197,6 +197,7 @@ void MapInputsAndDetermineDTypes(
   cfg.partitioning_info.collection_input_spec_map =
       ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map);
 
+  ir::TypeMap inferred_dtypes;
   auto collection_inputs = ir::get_collection_inputs(g, static_params);
   LOG_DEBUG(
       "In MapInputsAndDetermineDTypes, the g->inputs() size is "
@@ -218,13 +219,13 @@ void MapInputsAndDetermineDTypes(
         LOG_INFO(
             "Since input type is not explicitly defined, infering using first tensor calculation\n Inferred input "
             << in->debugName() << " has type " << est_type_opt[i].value());
-        spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value());
+        spec[i].dtype = est_type_opt[i].value();
       } else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) {
         // If we cannot calculate the type and the user did not define the type, then default to FP32
         LOG_WARNING(
             "Cannot infer input type from calcuations in graph for input "
             << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
-        spec[i].dtype = nvinfer1::DataType::kFLOAT;
+        spec[i].dtype = at::kFloat;
       } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
         if (!est_type_opt[i]) {
           LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
@@ -236,37 +237,35 @@ void MapInputsAndDetermineDTypes(
           auto warn_str = ss.str();
           LOG_WARNING(warn_str);
           // Overwrite type map with user settings
-          first_use_type_map[in][i] = {
-              util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
-
-        } else {
-          if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) !=
-              est_type_opt[i].value()) {
-            std::stringstream ss;
-            ss << "For input " << in->debugName() << ", found user specified input dtype as ";
-            ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
-            ss << ", however when inspecting the graph, the input type expected was inferred to be ";
-            ss << est_type_opt[i].value() << std::endl;
-            ss << "The compiler is going to use the user setting "
-               << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
-            ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
-            ss << "compatibility with PyTorch's data type convention is required.\n";
-            ss << "If you do indeed see errors at runtime either:\n";
-            ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
-            ss << "- Disable partial compilation by setting require_full_compilation to True";
-            auto warn_str = ss.str();
-            LOG_WARNING(warn_str);
-            // Overwrite type map with user settings
-            first_use_type_map[in][i] = {
-                util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
-          }
+          first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype};
+
+        } else if (cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype != est_type_opt[i].value()) {
+          std::stringstream ss;
+          ss << "For input " << in->debugName() << ", found user specified input dtype as ";
+          ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          ss << ", however when inspecting the graph, the input type expected was inferred to be ";
+          ss << est_type_opt[i].value() << std::endl;
+          ss << "The compiler is going to use the user setting "
+             << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
+          ss << "compatibility with PyTorch's data type convention is required.\n";
+          ss << "If you do indeed see errors at runtime either:\n";
+          ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
+          ss << "- Disable partial compilation by setting require_full_compilation to True";
+          auto warn_str = ss.str();
+          LOG_WARNING(warn_str);
+          // Overwrite type map with user settings
+          first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype};
        }
      } else {
        // The user defined the type so no changes are necessary
      }
+
+      // Insert entry for Value pointer and determined ScalarType
+      inferred_dtypes.insert({in, {spec[i].dtype}});
    }
  }
-  // }
+  return inferred_dtypes;
 }
 
 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
@@ -284,6 +283,15 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
 
   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
 
+  // Ensure none of the specified types are of acceptable input types incompatible with TRT
+  // Currently, only at::kLong is an acceptable, though TRT-incompatible type
+  for (auto value_to_dtypes : first_use_types) {
+    for (auto dtype : value_to_dtypes.second) {
+      TORCHTRT_CHECK(
+          !dtype || dtype.value() != at::kLong, "Cannot specify Int64 input for a model fully compiled in TRT");
+    }
+  }
+
   auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
 
   return engine;
@@ -307,10 +315,24 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // Infer the type of an input from the weights of the calculation
   auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
 
-  MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+  // Extract map of IValue to DType
+  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+
+  // Check whether any of the input types are Long
+  bool user_requested_long = false;
+  for (auto dtype : type_map) {
+    user_requested_long |= dtype.second && (dtype.second.value() == at::kLong);
+  }
+
+  // Use dtype map to autocast Tensor-type inputs to Long dtype as necessary
+  if (cfg.partitioning_info.enabled && cfg.partitioning_info.truncate_long_and_double && user_requested_long) {
+    auto casts_inserted = lowering::AutocastLongInputs(g, type_map, cfg.lower_info.getGPUDeviceString());
+    user_requested_long &= (casts_inserted > 0);
+  }
+
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
   auto outputIsCollection = conversion::OutputIsCollection(g->block());
-  if (cfg.partitioning_info.enabled &&
+  if (cfg.partitioning_info.enabled && !user_requested_long &&
       (cfg.lower_info.forced_fallback_modules.size() == 0 &&
        cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
       !outputIsCollection) {
@@ -320,7 +342,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   if (cfg.partitioning_info.enabled &&
       (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
          cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
-       outputIsCollection)) {
+       outputIsCollection || user_requested_long)) {
     auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
     new_g = graph_and_mapping.first;
     // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
```
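
For orientation, the `ir::TypeMap` returned here and the `CollectionTypeMap` behind `first_use_types` are presumably aliases along these lines, inferred from how this diff indexes them; core/ir/ir.h holds the authoritative definitions:

```cpp
#include <unordered_map>
#include <vector>
#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <torch/csrc/jit/ir/ir.h>

// Assumed shape of the ir type aliases (not shown in this diff): one optional
// ScalarType per graph Value, and a vector of them for collection inputs.
using TypeMap = std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>>;
using CollectionTypeMap =
    std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>>;
```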

core/conversion/conversion.cpp (+1, -1)

```diff
@@ -183,7 +183,7 @@ void AddInputs(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> input
         "Adding Input " << in->debugName() << " (named: " << name << "): " << spec
                         << " in engine (conversion.AddInputs)");
 
-    auto trt_in = ctx->net->addInput(name.c_str(), spec.dtype, spec.input_shape);
+    auto trt_in = ctx->net->addInput(name.c_str(), util::ScalarTypeToTRTDataType(spec.dtype), spec.input_shape);
     TORCHTRT_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)");
     trt_in->setAllowedFormats(1U << static_cast<int>(spec.format));
 
```

core/ir/Input.cpp (+6, -6)

```diff
@@ -71,7 +71,7 @@ bool valid_input_dtype(nvinfer1::DataType dtype) {
 
 Input::Input(
     std::vector<int64_t> shape,
-    nvinfer1::DataType dtype,
+    at::ScalarType dtype,
     nvinfer1::TensorFormat format,
     bool dtype_is_user_defined) {
   if (shape.size() > 5) {
@@ -84,10 +84,10 @@ Input::Input(
   input_shape = util::toDims(shape);
   input_is_dynamic = false;
 
-  TORCHTRT_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
+  TORCHTRT_CHECK(valid_input_dtype(util::ScalarTypeToTRTDataType(dtype)), "Unsupported input data type: " << dtype);
   this->dtype = dtype;
   TORCHTRT_CHECK(
-      valid_dtype_format_combo(dtype, format),
+      valid_dtype_format_combo(util::ScalarTypeToTRTDataType(dtype), format),
       "Unsupported combination of dtype and tensor format: ("
           << dtype << ", " << format
           << "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
@@ -99,7 +99,7 @@ Input::Input(
     std::vector<int64_t> min_shape,
     std::vector<int64_t> opt_shape,
     std::vector<int64_t> max_shape,
-    nvinfer1::DataType dtype,
+    at::ScalarType dtype,
     nvinfer1::TensorFormat format,
     bool dtype_is_user_defined) {
   if (min_shape.size() > 5 || opt_shape.size() > 5 || max_shape.size() > 5) {
@@ -137,10 +137,10 @@ Input::Input(
 
   input_shape = util::toDims(dyn_shape);
 
-  TORCHTRT_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
+  TORCHTRT_CHECK(valid_input_dtype(util::ScalarTypeToTRTDataType(dtype)), "Unsupported input data type: " << dtype);
   this->dtype = dtype;
   TORCHTRT_CHECK(
-      valid_dtype_format_combo(dtype, format),
+      valid_dtype_format_combo(util::ScalarTypeToTRTDataType(dtype), format),
       "Unsupported combination of dtype and tensor format: ("
           << dtype << ", " << format
           << "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
```

core/ir/ir.h (+4, -3)

```diff
@@ -29,16 +29,17 @@ struct Input : torch::CustomClassHolder {
   Input(){};
   Input(
       std::vector<int64_t> shape,
-      nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
+      at::ScalarType dtype = at::kFloat,
       nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
       bool dtype_is_user_defined = false);
   Input(
       std::vector<int64_t> min_shape,
       std::vector<int64_t> opt_shape,
       std::vector<int64_t> max_shape,
-      nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
+      at::ScalarType dtype = at::kFloat,
       nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
       bool dtype_is_used_defined = false);
+
   friend std::ostream& operator<<(std::ostream& os, const Input& input);
 
   bool input_is_dynamic = false;
@@ -47,7 +48,7 @@ struct Input : torch::CustomClassHolder {
   nvinfer1::Dims min;
   nvinfer1::Dims max;
   nvinfer1::Dims opt;
-  nvinfer1::DataType dtype;
+  at::ScalarType dtype;
   nvinfer1::TensorFormat format;
   int id;
 };
```
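
The practical payoff of storing `at::ScalarType` here is that `ir::Input` can now describe types TensorRT itself cannot, which is what the commit message means by a larger representation space. A small illustration (not from the diff):

```cpp
#include <ATen/ATen.h>

// Representable on ir::Input now that dtype is an at::ScalarType, even though
// nvinfer1::DataType has no direct counterpart for either:
at::ScalarType long_input = at::kLong;     // supported via truncation to INT32
at::ScalarType double_input = at::kDouble; // groundwork for future FP64 support
```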

core/lowering/lowering.cpp (+57)

```diff
@@ -26,6 +26,63 @@ void LowerBlock(torch::jit::Block* b) {
   DropUnusedNodes(b);
 }
 
+int AutocastLongInputs(
+    std::shared_ptr<torch::jit::Graph>& g,
+    ir::TypeMap input_type_map,
+    std::string target_device_name) {
+  int num_autocasts = 0;
+  // For each graph input, determine if it can be autocasted
+  for (int i = 0; i < g->inputs().size(); i++) {
+    auto input = g->inputs()[i];
+
+    // Autocasted inputs must be Tensor-type
+    if (input->type()->isSubtypeOf(c10::TensorType::get())) {
+      auto dtype_input = input_type_map.find(input);
+
+      // Ensure the data type to be casted to exists in the type map
+      if (dtype_input == input_type_map.end() || !dtype_input->second) {
+        LOG_DEBUG("No inferred input dtype for tensor " << input->debugName() << ", skipping autocast");
+        continue;
+      }
+
+      auto dtype = dtype_input->second.value();
+      // Currently, we do not autocast inputs for which the determined type is not long
+      if (dtype != at::kLong) {
+        continue;
+      }
+
+      LOG_DEBUG("Inserting aten::to casting " << input->debugName() << " to dtype " << dtype);
+
+      // Generate cast node sending input tensors to the inferred or specified datatype (long)
+      auto const_type = g->insertConstant(dtype);
+      auto const_false = g->insertConstant(0);
+      const_false->setType(torch::jit::BoolType::get());
+      auto cuda = g->insertConstant(target_device_name);
+      cuda->setType(torch::jit::DeviceObjType::get());
+      auto none_val = g->insertNode(g->createNone())->output();
+      auto cast_node = g->create(torch::jit::aten::to, {input, cuda, const_type, const_false, const_false, none_val});
+
+      // Replace all uses of the original tensor with that of the casted tensor
+      g->prependNode(cast_node);
+      input->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
+
+      // Mark the cast node to run in PyTorch for ease of casting
+      LOG_GRAPH("Marking autocast node " << util::node_info(cast_node) << " to run in PyTorch");
+      cast_node->i_(c10::Symbol::attr("to_compile"), (int64_t) false);
+      num_autocasts++;
+    }
+  }
+
+  LOG_WARNING(
+      "Input tensors to this Torch-TRT engine may have their data types in-place modified "
+      << "if the type does not match the determined required type for TRT. To disable this "
+      << "automatic casting, specify an Input dtype other than Long");
+
+  LOG_GRAPH("Graph after Autocast: " << *g);
+
+  return num_autocasts;
+}
+
 void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params, LowerInfo lower_info) {
   torch::jit::EliminateRedundantGuards(g);
   torch::jit::RemoveListMutation(g);
```
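
The node assembled above targets the `aten::to.device` overload `(Tensor self, Device device, ScalarType dtype, bool non_blocking, bool copy, MemoryFormat? memory_format)`. In eager terms the inserted prelude behaves roughly like this sketch (device and tensor are illustrative; Torch-TRT uses the device string from `lower_info.getGPUDeviceString()`):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // Stand-in for a graph input whose determined dtype is Long: the autocast
  // prelude upcasts it on the target device before any engine consumes it.
  auto input = torch::ones({2, 3}, torch::kInt);
  auto casted = input.to(torch::Device("cpu"), torch::kLong,
                         /*non_blocking=*/false, /*copy=*/false);
  std::cout << casted.dtype() << std::endl;  // Long
  return 0;
}
```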

core/lowering/lowering.h (+4)

```diff
@@ -27,6 +27,10 @@ struct LowerInfo {
 
 void LowerBlock(torch::jit::Block* b);
 void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info);
+int AutocastLongInputs(
+    std::shared_ptr<torch::jit::Graph>& g,
+    ir::TypeMap input_type_map,
+    std::string target_device_name);
 torch::jit::Module LowerModule(
     const torch::jit::Module& mod,
     std::string method_name,
```

core/partitioning/segmentedblock/SegmentedBlock.cpp (+2, -2)

```diff
@@ -62,13 +62,13 @@ std::vector<ir::Input> SegmentedBlock::construct_inputs_spec() const {
   if (min_shapes_.size() == opt_shapes_.size() && opt_shapes_.size() == max_shapes_.size()) {
     for (uint64_t i = 0; i < opt_shapes_.size(); i++) {
       auto in = ir::Input(min_shapes_[i], opt_shapes_[i], max_shapes_[i]);
-      in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]);
+      in.dtype = in_types_[i];
       inputs.push_back(in);
     }
   } else {
     for (uint64_t i = 0; i < opt_shapes_.size(); i++) {
       auto in = ir::Input(opt_shapes_[i]);
-      in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]);
+      in.dtype = in_types_[i];
       inputs.push_back(in);
     }
   }
```

core/partitioning/shape_analysis.cpp (+2, -2)

```diff
@@ -266,10 +266,10 @@ void getSegmentsOutputByRunning(
           "Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled");
     } else if (partitioning_info.truncate_long_and_double && t == at::kLong) {
       cur_ivalue = cur_ivalue.toTensor().to(at::kInt);
-      LOG_WARNING("Truncating graph input type from at::kLong to at::kInt");
+      LOG_WARNING("Truncating intermediate graph input type from at::kLong to at::kInt");
     } else if (partitioning_info.truncate_long_and_double && t == at::kDouble) {
       cur_ivalue = cur_ivalue.toTensor().to(at::kFloat);
-      LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat");
+      LOG_WARNING("Truncating intermediate graph input type from at::kDouble to at::kFloat");
     }
     c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype());
     if (dtype == c10::nullopt) {
```
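
These warnings exist because the truncation is lossy: Int64 values outside the Int32 range are not preserved. A small demonstration of the same `.to(at::kInt)` conversion (illustrative, not from the diff):

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
  // 2^40 fits in Int64 but not in Int32, so the conversion used during shape
  // analysis silently changes such values.
  auto v = torch::tensor({int64_t(1) << 40}, torch::kLong);
  std::cout << v.item<int64_t>() << " -> " << v.to(torch::kInt).item<int32_t>() << std::endl;
  return 0;
}
```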

core/util/trt_util.cpp (+1)

```diff
@@ -251,6 +251,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
       {at::kFloat, nvinfer1::DataType::kFLOAT},
       {at::kHalf, nvinfer1::DataType::kHALF},
       {at::kInt, nvinfer1::DataType::kINT32},
+      {at::kLong, nvinfer1::DataType::kINT32},
       {at::kChar, nvinfer1::DataType::kINT8},
       {at::kBool, nvinfer1::DataType::kBOOL}};
   return at_trt_type_map;
```
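
With this entry, `util::ScalarTypeToTRTDataType` (the lookup used in conversion.cpp above) resolves `at::kLong` to TensorRT's 32-bit integer type instead of failing, which is the mechanism behind the Int32 downcast. A sketch of the effect, with the header path and namespace assumed from this diff's `util::` calls:

```cpp
#include <ATen/ATen.h>
#include "core/util/trt_util.h"  // header path assumed from the repository layout

// Illustrative only: the ATen-to-TRT lookup now succeeds for Long.
void long_maps_to_int32() {
  auto trt_type = torch_tensorrt::core::util::ScalarTypeToTRTDataType(at::kLong);
  // Expected: trt_type == nvinfer1::DataType::kINT32
  (void)trt_type;
}
```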

cpp/include/torch_tensorrt/torch_tensorrt.h (+2)

```diff
@@ -58,6 +58,8 @@ class DataType {
    * ex. torch_tensorrt::DataType type = DataType::kFloat;
    */
   enum Value : int8_t {
+    /// INT64
+    kLong,
     /// FP32
     kFloat,
     /// FP16
```

cpp/src/types.cpp (+20, -1)

```diff
@@ -87,6 +87,25 @@ nvinfer1::DataType toTRTDataType(DataType value) {
   }
 }
 
+at::ScalarType toAtDataType(DataType value) {
+  switch (value) {
+    case DataType::kChar:
+      return at::kChar;
+    case DataType::kHalf:
+      return at::kHalf;
+    case DataType::kInt:
+      return at::kInt;
+    case DataType::kLong:
+      return at::kLong;
+    case DataType::kBool:
+      return at::kBool;
+    case DataType::kFloat:
+    case DataType::kUnknown:
+    default:
+      return at::kFloat;
+  }
+}
+
 nvinfer1::TensorFormat toTRTTensorFormat(TensorFormat value) {
   TORCHTRT_CHECK(!(value == TensorFormat::kUnknown), "Tensor format is unknown");
   switch (value) {
@@ -267,7 +286,7 @@ torch_tensorrt::core::ir::Input to_internal_input(Input& i) {
       i.min_shape,
       i.opt_shape,
       i.max_shape,
-      toTRTDataType(i.dtype),
+      toAtDataType(i.dtype),
       toTRTTensorFormat(i.format),
       !(i.dtype == DataType::kUnknown));
 }
```
