
Commit 835abf0

Merge pull request #1551 from gs-olive/autocast
feat: Add option to specify int64 as an Input dtype
2 parents: dc570e4 + 4282c06 · commit 835abf0

19 files changed: +290 -54 lines

core/compiler.cpp

+53 -31

@@ -187,7 +187,7 @@ partitioning::GraphAndMapping BuildHybridGraph(
   return partitioning::stitch(&partitioning_ctx, block);
 }
 
-void MapInputsAndDetermineDTypes(
+ir::TypeMap MapInputsAndDetermineDTypes(
     CompileSpec& cfg,
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
@@ -197,6 +197,7 @@ void MapInputsAndDetermineDTypes(
   cfg.partitioning_info.collection_input_spec_map =
       ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map);
 
+  ir::TypeMap inferred_dtypes;
   auto collection_inputs = ir::get_collection_inputs(g, static_params);
   LOG_DEBUG(
       "In MapInputsAndDetermineDTypes, the g->inputs() size is "
@@ -218,13 +219,13 @@ void MapInputsAndDetermineDTypes(
         LOG_INFO(
             "Since input type is not explicitly defined, infering using first tensor calculation\n Inferred input "
             << in->debugName() << " has type " << est_type_opt[i].value());
-        spec[i].dtype = util::ScalarTypeToTRTDataType(est_type_opt[i].value());
+        spec[i].dtype = est_type_opt[i].value();
       } else if (!est_type_opt[i] && !spec[i].dtype_is_user_defined) {
         // If we cannot calculate the type and the user did not define the type, then default to FP32
         LOG_WARNING(
             "Cannot infer input type from calcuations in graph for input "
             << in->debugName() << ". Assuming it is Float32. If not, specify input type explicity");
-        spec[i].dtype = nvinfer1::DataType::kFLOAT;
+        spec[i].dtype = at::kFloat;
       } else if (spec[i].dtype_is_user_defined && cfg.partitioning_info.enabled) {
         if (!est_type_opt[i]) {
           LOG_INFO("Cannot infer input tensor dtype in graph, compiler is going to use the user setting");
@@ -236,37 +237,35 @@ void MapInputsAndDetermineDTypes(
           auto warn_str = ss.str();
           LOG_WARNING(warn_str);
           // Overwrite type map with user settings
-          first_use_type_map[in][i] = {
-              util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
-
-        } else {
-          if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) !=
-              est_type_opt[i].value()) {
-            std::stringstream ss;
-            ss << "For input " << in->debugName() << ", found user specified input dtype as ";
-            ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
-            ss << ", however when inspecting the graph, the input type expected was inferred to be ";
-            ss << est_type_opt[i].value() << std::endl;
-            ss << "The compiler is going to use the user setting "
-               << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
-            ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
-            ss << "compatibility with PyTorch's data type convention is required.\n";
-            ss << "If you do indeed see errors at runtime either:\n";
-            ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
-            ss << "- Disable partial compilation by setting require_full_compilation to True";
-            auto warn_str = ss.str();
-            LOG_WARNING(warn_str);
-            // Overwrite type map with user settings
-            first_use_type_map[in][i] = {
-                util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
-          }
+          first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype};
+
+        } else if (cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype != est_type_opt[i].value()) {
+          std::stringstream ss;
+          ss << "For input " << in->debugName() << ", found user specified input dtype as ";
+          ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          ss << ", however when inspecting the graph, the input type expected was inferred to be ";
+          ss << est_type_opt[i].value() << std::endl;
+          ss << "The compiler is going to use the user setting "
+             << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+          ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
+          ss << "compatibility with PyTorch's data type convention is required.\n";
+          ss << "If you do indeed see errors at runtime either:\n";
+          ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
+          ss << "- Disable partial compilation by setting require_full_compilation to True";
+          auto warn_str = ss.str();
+          LOG_WARNING(warn_str);
+          // Overwrite type map with user settings
+          first_use_type_map[in][i] = {cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype};
         }
       } else {
        // The user defined the type so no changes are necessary
      }
+
+      // Insert entry for Value pointer and determined ScalarType
+      inferred_dtypes.insert({in, {spec[i].dtype}});
    }
  }
-  // }
+  return inferred_dtypes;
 }
 
 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
@@ -284,6 +283,15 @@ std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::
 
   MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
 
+  // Ensure none of the specified types are of acceptable input types incompatible with TRT
+  // Currently, only at::kLong is an acceptable, though TRT-incompatible type
+  for (auto value_to_dtypes : first_use_types) {
+    for (auto dtype : value_to_dtypes.second) {
+      TORCHTRT_CHECK(
+          !dtype || dtype.value() != at::kLong, "Cannot specify Int64 input for a model fully compiled in TRT");
+    }
+  }
+
   auto engine = conversion::ConvertBlockToEngine(g->block(), cfg.convert_info, static_params);
 
   return engine;
@@ -307,10 +315,24 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // Infer the type of an input from the weights of the calculation
   auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
 
-  MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+  // Extract map of IValue to DType
+  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+
+  // Check whether any of the input types are Long
+  bool user_requested_long = false;
+  for (auto dtype : type_map) {
+    user_requested_long |= dtype.second && (dtype.second.value() == at::kLong);
+  }
+
+  // Use dtype map to autocast Tensor-type inputs to Long dtype as necessary
+  if (cfg.partitioning_info.enabled && cfg.partitioning_info.truncate_long_and_double && user_requested_long) {
+    auto casts_inserted = lowering::AutocastLongInputs(g, type_map, cfg.lower_info.getGPUDeviceString());
+    user_requested_long &= (casts_inserted > 0);
+  }
+
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
   auto outputIsCollection = conversion::OutputIsCollection(g->block());
-  if (cfg.partitioning_info.enabled &&
+  if (cfg.partitioning_info.enabled && !user_requested_long &&
       (cfg.lower_info.forced_fallback_modules.size() == 0 &&
        cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) &&
       !outputIsCollection) {
@@ -320,7 +342,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   if (cfg.partitioning_info.enabled &&
       (!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
          cfg.partitioning_info.forced_fallback_operators.size() == 0 && isBlockConvertible) ||
-       outputIsCollection)) {
+       outputIsCollection || user_requested_long)) {
     auto graph_and_mapping = BuildHybridGraph(new_mod, g->block(), cfg, static_params, first_use_types);
     new_g = graph_and_mapping.first;
     // renaming the input name of graph after fallback to ensure pytorch deserialize it correctly
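To summarize the control flow these hunks add: MapInputsAndDetermineDTypes now returns the per-input dtype map, and Int64 inputs either fail fast (full-compilation path) or force the hybrid graph plus an autocast pass (partitioned path). Below is a minimal standalone sketch of that decision, using libtorch types but a string-keyed map in place of the internal ir::TypeMap (which is keyed by torch::jit::Value pointers); the map contents and the two boolean flags are hypothetical stand-ins for CompileSpec settings.

// long_input_policy.cpp -- sketch of the Int64-input policy introduced in this commit.
#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <iostream>
#include <string>
#include <unordered_map>

int main() {
  // Simplified stand-in for ir::TypeMap: inferred or user-specified dtype per graph input.
  std::unordered_map<std::string, c10::optional<at::ScalarType>> type_map = {
      {"input_0", at::kLong},   // user requested an INT64 input
      {"input_1", at::kFloat}};

  // Hypothetical stand-ins for the relevant CompileSpec flags.
  bool partitioning_enabled = true;
  bool truncate_long_and_double = true;

  // Mirror of the new check in CompileGraph: is any input Long?
  bool user_requested_long = false;
  for (const auto& kv : type_map) {
    user_requested_long |= kv.second && (kv.second.value() == at::kLong);
  }

  if (!partitioning_enabled || !truncate_long_and_double) {
    // ConvertGraphToTRTEngine path: Long inputs are rejected outright.
    if (user_requested_long) {
      std::cerr << "Cannot specify Int64 input for a model fully compiled in TRT\n";
      return 1;
    }
  } else if (user_requested_long) {
    // CompileGraph path: Long inputs force the hybrid (partitioned) graph,
    // and AutocastLongInputs inserts aten::to casts for those inputs.
    std::cout << "Int64 input detected: forcing partial compilation + autocast\n";
  }
  return 0;
}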

core/conversion/conversion.cpp

+1 -1

@@ -183,7 +183,7 @@ void AddInputs(ConversionCtx* ctx, c10::ArrayRef<const torch::jit::Value*> input
         "Adding Input " << in->debugName() << " (named: " << name << "): " << spec
                         << " in engine (conversion.AddInputs)");
 
-    auto trt_in = ctx->net->addInput(name.c_str(), spec.dtype, spec.input_shape);
+    auto trt_in = ctx->net->addInput(name.c_str(), util::ScalarTypeToTRTDataType(spec.dtype), spec.input_shape);
     TORCHTRT_CHECK(trt_in, "Failed to add input node: " << in->debugName() << " (conversion.AddInputs)");
     trt_in->setAllowedFormats(1U << static_cast<int>(spec.format));
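Since ir::Input::dtype is now an at::ScalarType, the conversion to TensorRT's enum happens only here, at the network boundary. The following is a hedged, standalone TensorRT-API sketch of what the converted call amounts to for an Int64 spec (which ScalarTypeToTRTDataType now maps to kINT32); the logger, shape, and tensor name are illustrative only.

// trt_input_boundary.cpp -- sketch: an INT64 Input spec surfaces as an INT32 engine input.
#include <NvInfer.h>
#include <iostream>

class SimpleLogger : public nvinfer1::ILogger {
  void log(Severity severity, const char* msg) noexcept override {
    if (severity <= Severity::kWARNING) std::cout << msg << std::endl;
  }
};

int main() {
  SimpleLogger logger;
  auto builder = nvinfer1::createInferBuilder(logger);
  auto network = builder->createNetworkV2(
      1U << static_cast<uint32_t>(nvinfer1::NetworkDefinitionCreationFlag::kEXPLICIT_BATCH));

  // Equivalent of ctx->net->addInput(name, util::ScalarTypeToTRTDataType(spec.dtype), spec.input_shape)
  // when spec.dtype == at::kLong: the TRT-side dtype is kINT32.
  auto* trt_in = network->addInput("input_0", nvinfer1::DataType::kINT32, nvinfer1::Dims4{1, 3, 224, 224});
  trt_in->setAllowedFormats(1U << static_cast<int>(nvinfer1::TensorFormat::kLINEAR));

  std::cout << "Added input with TRT dtype INT32 for an at::kLong spec" << std::endl;

  delete network;
  delete builder;
  return 0;
}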

core/ir/Input.cpp

+6 -6

@@ -71,7 +71,7 @@ bool valid_input_dtype(nvinfer1::DataType dtype) {
 
 Input::Input(
     std::vector<int64_t> shape,
-    nvinfer1::DataType dtype,
+    at::ScalarType dtype,
     nvinfer1::TensorFormat format,
     bool dtype_is_user_defined) {
   if (shape.size() > 5) {
@@ -84,10 +84,10 @@ Input::Input(
   input_shape = util::toDims(shape);
   input_is_dynamic = false;
 
-  TORCHTRT_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
+  TORCHTRT_CHECK(valid_input_dtype(util::ScalarTypeToTRTDataType(dtype)), "Unsupported input data type: " << dtype);
   this->dtype = dtype;
   TORCHTRT_CHECK(
-      valid_dtype_format_combo(dtype, format),
+      valid_dtype_format_combo(util::ScalarTypeToTRTDataType(dtype), format),
       "Unsupported combination of dtype and tensor format: ("
           << dtype << ", " << format
           << "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");
@@ -99,7 +99,7 @@ Input::Input(
     std::vector<int64_t> min_shape,
     std::vector<int64_t> opt_shape,
     std::vector<int64_t> max_shape,
-    nvinfer1::DataType dtype,
+    at::ScalarType dtype,
     nvinfer1::TensorFormat format,
     bool dtype_is_user_defined) {
   if (min_shape.size() > 5 || opt_shape.size() > 5 || max_shape.size() > 5) {
@@ -137,10 +137,10 @@ Input::Input(
 
   input_shape = util::toDims(dyn_shape);
 
-  TORCHTRT_CHECK(valid_input_dtype(dtype), "Unsupported input data type: " << dtype);
+  TORCHTRT_CHECK(valid_input_dtype(util::ScalarTypeToTRTDataType(dtype)), "Unsupported input data type: " << dtype);
   this->dtype = dtype;
   TORCHTRT_CHECK(
-      valid_dtype_format_combo(dtype, format),
+      valid_dtype_format_combo(util::ScalarTypeToTRTDataType(dtype), format),
       "Unsupported combination of dtype and tensor format: ("
           << dtype << ", " << format
           << "), Torch-TensorRT only supports contiguous format (NCHW) except with input type Float32 where channel last (NHWC) is also supported");

core/ir/ir.h

+4 -3

@@ -29,16 +29,17 @@ struct Input : torch::CustomClassHolder {
   Input(){};
   Input(
       std::vector<int64_t> shape,
-      nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
+      at::ScalarType dtype = at::kFloat,
       nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
       bool dtype_is_user_defined = false);
   Input(
       std::vector<int64_t> min_shape,
       std::vector<int64_t> opt_shape,
       std::vector<int64_t> max_shape,
-      nvinfer1::DataType dtype = nvinfer1::DataType::kFLOAT,
+      at::ScalarType dtype = at::kFloat,
       nvinfer1::TensorFormat format = nvinfer1::TensorFormat::kLINEAR,
       bool dtype_is_used_defined = false);
+
   friend std::ostream& operator<<(std::ostream& os, const Input& input);
 
   bool input_is_dynamic = false;
@@ -47,7 +48,7 @@ struct Input : torch::CustomClassHolder {
   nvinfer1::Dims min;
   nvinfer1::Dims max;
   nvinfer1::Dims opt;
-  nvinfer1::DataType dtype;
+  at::ScalarType dtype;
   nvinfer1::TensorFormat format;
   int id;
 };

core/lowering/lowering.cpp

+69 -0

@@ -26,6 +26,75 @@ void LowerBlock(torch::jit::Block* b) {
   DropUnusedNodes(b);
 }
 
+int AutocastLongInputs(
+    std::shared_ptr<torch::jit::Graph>& g,
+    ir::TypeMap input_type_map,
+    std::string target_device_name) {
+  int num_autocasts = 0;
+  // For each graph input, determine if it can be autocasted
+  for (int i = 0; i < g->inputs().size(); i++) {
+    auto input = g->inputs()[i];
+
+    // Autocasted inputs must be Tensor-type
+    if (input->type()->isSubtypeOf(c10::TensorType::get())) {
+      auto dtype_input = input_type_map.find(input);
+
+      // Ensure the data type to be casted to exists in the type map
+      if (dtype_input == input_type_map.end() || !dtype_input->second) {
+        LOG_DEBUG("No inferred input dtype for tensor " << input->debugName() << ", skipping autocast");
+        continue;
+      }
+
+      auto dtype = dtype_input->second.value();
+      // Currently, we do not autocast inputs for which the determined type is not long
+      if (dtype != at::kLong) {
+        LOG_DEBUG(
+            "Skipping autocast for tensor " << input->debugName() << ", since its dtype is " << dtype
+                                            << " and not at::kLong");
+        continue;
+      }
+
+      LOG_DEBUG("Inserting aten::to casting " << input->debugName() << " to dtype " << dtype);
+
+      // Generate cast node sending input tensors to the inferred or specified datatype (long)
+      torch::jit::Value *const_false, *cuda, *none_val;
+      if (num_autocasts == 0) {
+        // Only generate constants once and reuse for all autocasts
+        const_false = g->insertConstant(0);
+        const_false->setType(torch::jit::BoolType::get());
+        cuda = g->insertConstant(target_device_name);
+        cuda->setType(torch::jit::DeviceObjType::get());
+        none_val = g->insertNode(g->createNone())->output();
+      }
+
+      auto const_type = g->insertConstant(dtype);
+      auto cast_node = g->create(torch::jit::aten::to, {input, cuda, const_type, const_false, const_false, none_val});
+
+      // Replace all uses of the original tensor with that of the casted tensor
+      g->prependNode(cast_node);
+      input->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
+
+      // Mark the cast node to run in PyTorch for ease of casting
+      LOG_GRAPH("Marking autocast node " << util::node_info(cast_node) << " to run in PyTorch");
+      cast_node->i_(c10::Symbol::attr("to_compile"), (int64_t) false);
+      num_autocasts++;
+    }
+  }
+
+  LOG_GRAPH("Inserted " << num_autocasts << " autocasts");
+
+  if (num_autocasts > 0) {
+    LOG_WARNING(
+        "Data types for input tensors have been modified by inserting "
+        << "aten::to operations which cast INT64 inputs to INT32. "
+        << "To disable this, please recompile using INT32 inputs");
+
+    LOG_GRAPH("Graph after Autocast: " << *g);
+  }
+
+  return num_autocasts;
+}
+
 void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params, LowerInfo lower_info) {
   torch::jit::EliminateRedundantGuards(g);
   torch::jit::RemoveListMutation(g);
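The cast node assembled above appears to target the aten::to.device overload (self, device, dtype, non_blocking, copy, memory_format). A hedged eager-mode sketch of the same call is shown below, run on CPU here rather than the GPU device string the lowering pass uses; the tensor contents are illustrative only.

// autocast_to_sketch.cpp -- eager-mode equivalent of the inserted aten::to node.
#include <torch/torch.h>
#include <iostream>

int main() {
  // An INT32 tensor standing in for the engine-facing input at runtime.
  auto input = torch::arange(6, torch::kInt).reshape({2, 3});

  // aten::to(self, device, dtype, non_blocking=False, copy=False) -- the overload
  // the autocast node wires up, with device = "cpu" here instead of e.g. "cuda:0".
  auto casted = input.to(torch::Device("cpu"), torch::kLong,
                         /*non_blocking=*/false, /*copy=*/false);

  std::cout << "original dtype: " << input.dtype()
            << ", casted dtype: " << casted.dtype() << std::endl;
  return 0;
}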

core/lowering/lowering.h

+4 -0

@@ -27,6 +27,10 @@ struct LowerInfo {
 
 void LowerBlock(torch::jit::Block* b);
 void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info);
+int AutocastLongInputs(
+    std::shared_ptr<torch::jit::Graph>& g,
+    ir::TypeMap input_type_map,
+    std::string target_device_name);
 torch::jit::Module LowerModule(
     const torch::jit::Module& mod,
     std::string method_name,

core/partitioning/segmentedblock/SegmentedBlock.cpp

+2 -2

@@ -62,13 +62,13 @@ std::vector<ir::Input> SegmentedBlock::construct_inputs_spec() const {
   if (min_shapes_.size() == opt_shapes_.size() && opt_shapes_.size() == max_shapes_.size()) {
     for (uint64_t i = 0; i < opt_shapes_.size(); i++) {
       auto in = ir::Input(min_shapes_[i], opt_shapes_[i], max_shapes_[i]);
-      in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]);
+      in.dtype = in_types_[i];
       inputs.push_back(in);
     }
   } else {
     for (uint64_t i = 0; i < opt_shapes_.size(); i++) {
       auto in = ir::Input(opt_shapes_[i]);
-      in.dtype = util::ScalarTypeToTRTDataType(in_types_[i]);
+      in.dtype = in_types_[i];
       inputs.push_back(in);
     }
   }

core/partitioning/shape_analysis.cpp

+2 -2

@@ -303,10 +303,10 @@ void getSegmentsOutputByRunning(
         "Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled");
   } else if (partitioning_info.truncate_long_and_double && t == at::kLong) {
     cur_ivalue = cur_ivalue.toTensor().to(at::kInt);
-    LOG_WARNING("Truncating graph input type from at::kLong to at::kInt");
+    LOG_WARNING("Truncating intermediate graph input type from at::kLong to at::kInt");
   } else if (partitioning_info.truncate_long_and_double && t == at::kDouble) {
     cur_ivalue = cur_ivalue.toTensor().to(at::kFloat);
-    LOG_WARNING("Truncating graph input type from at::kDouble to at::kFloat");
+    LOG_WARNING("Truncating intermediate graph input type from at::kDouble to at::kFloat");
   }
 
   c10::optional<nvinfer1::DataType> dtype = util::optTypeMetaToTRTDataType(cur_ivalue.toTensor().dtype());
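The reworded warnings refer to intermediate (subgraph) inputs, which are still truncated with a plain .to() during shape analysis. A small eager-mode sketch of that truncation follows, with made-up values chosen to show why entries outside the INT32 range are at risk.

// truncate_sketch.cpp -- what .to(at::kInt) does to an at::kLong intermediate value.
#include <torch/torch.h>
#include <iostream>

int main() {
  // A Long tensor with one value that does not fit in INT32.
  auto long_vals = torch::tensor({static_cast<int64_t>(42), static_cast<int64_t>(3000000000LL)}, torch::kLong);

  // Mirror of cur_ivalue.toTensor().to(at::kInt) in shape_analysis.cpp.
  auto int_vals = long_vals.to(torch::kInt);

  std::cout << "before: " << long_vals << "\nafter:  " << int_vals << std::endl;
  // The 3000000000 entry wraps after truncation to 32 bits.
  return 0;
}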

core/util/trt_util.cpp

+1 -0

@@ -251,6 +251,7 @@ const std::unordered_map<at::ScalarType, nvinfer1::DataType>& get_at_trt_type_ma
       {at::kFloat, nvinfer1::DataType::kFLOAT},
       {at::kHalf, nvinfer1::DataType::kHALF},
       {at::kInt, nvinfer1::DataType::kINT32},
+      {at::kLong, nvinfer1::DataType::kINT32},
       {at::kChar, nvinfer1::DataType::kINT8},
       {at::kByte, nvinfer1::DataType::kINT8},
       {at::kBool, nvinfer1::DataType::kBOOL}};
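A self-contained sketch of the lookup this new entry enables, at::kLong resolving to TensorRT's INT32 alongside the pre-existing entries, is shown below; the map here is a local copy for illustration, not the library's get_at_trt_type_map().

// at_trt_type_map_sketch.cpp -- local copy of the ScalarType -> TRT DataType mapping.
#include <ATen/ATen.h>
#include <NvInfer.h>
#include <iostream>
#include <unordered_map>

int main() {
  const std::unordered_map<at::ScalarType, nvinfer1::DataType> at_trt_type_map = {
      {at::kFloat, nvinfer1::DataType::kFLOAT},
      {at::kHalf, nvinfer1::DataType::kHALF},
      {at::kInt, nvinfer1::DataType::kINT32},
      {at::kLong, nvinfer1::DataType::kINT32}, // new: INT64 specs land on INT32 engine inputs
      {at::kChar, nvinfer1::DataType::kINT8},
      {at::kByte, nvinfer1::DataType::kINT8},
      {at::kBool, nvinfer1::DataType::kBOOL}};

  auto trt_dtype = at_trt_type_map.at(at::kLong);
  std::cout << "at::kLong maps to TRT dtype "
            << static_cast<int>(trt_dtype) << " (kINT32)" << std::endl;
  return 0;
}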

cpp/include/torch_tensorrt/torch_tensorrt.h

+2 -0

@@ -58,6 +58,8 @@ class DataType {
    * ex. torch_tensorrt::DataType type = DataType::kFloat;
    */
   enum Value : int8_t {
+    /// INT64
+    kLong,
     /// FP32
     kFloat,
     /// FP16