Commit 3677fd3
feat: Automatically cast user inputs to inferred data type
- Add a post-lowering pass to insert `aten::to` operators for Tensor inputs determined to require float or int values
- Specifically, if the user provides a non-float input for a float-dtype input field and has `truncate_long_and_double=True`, a Torch-executed graph block is inserted which casts that input to a float in-place
- This operation modifies user-provided tensors and emits a warning to that effect
- Currently, the feature only works for Tensor inputs (not input signatures) and only casts to int and float types; if the input is specified as any other type, no cast is inserted
- Modify the compiler to extract inferred data types for each input
- Add testing to ensure casts are inserted correctly and run in Torch
1 parent: 2ef6c3a

4 files changed (+144, −24 lines)
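As context for the diffs below, here is a minimal sketch (not part of this commit) of how the new pass is reached through the core compile path. The helper name `compile_with_autocast` and the header path are illustrative assumptions; `CompileGraph` and the `partitioning_info` flags are taken from the core/compiler.cpp diff.

#include "core/compiler.h" // assumed header exposing CompileGraph and CompileSpec

// Hypothetical helper: per the CompileGraph change below, the autocast pass
// only runs when partitioning is enabled and truncate_long_and_double is set.
torch::jit::Module compile_with_autocast(const torch::jit::Module& mod, torch_tensorrt::core::CompileSpec cfg) {
  cfg.partitioning_info.enabled = true;
  cfg.partitioning_info.truncate_long_and_double = true;
  return torch_tensorrt::core::CompileGraph(mod, cfg);
}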

core/compiler.cpp

+34 −24
@@ -187,7 +187,7 @@ partitioning::GraphAndMapping BuildHybridGraph(
   return partitioning::stitch(&partitioning_ctx, block);
 }
 
-void MapInputsAndDetermineDTypes(
+ir::TypeMap MapInputsAndDetermineDTypes(
     CompileSpec& cfg,
     std::shared_ptr<torch::jit::Graph>& g,
     ir::StaticParams& static_params,
@@ -197,6 +197,7 @@ void MapInputsAndDetermineDTypes(
   cfg.partitioning_info.collection_input_spec_map =
       ir::CollectionInputSpecMap(cfg.convert_info.collection_input_spec_map);
 
+  ir::TypeMap inferred_dtypes;
   auto collection_inputs = ir::get_collection_inputs(g, static_params);
   LOG_DEBUG(
       "In MapInputsAndDetermineDTypes, the g->inputs() size is "
@@ -239,34 +240,36 @@ void MapInputsAndDetermineDTypes(
         first_use_type_map[in][i] = {
             util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
 
-      } else {
-        if (util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) !=
-            est_type_opt[i].value()) {
-          std::stringstream ss;
-          ss << "For input " << in->debugName() << ", found user specified input dtype as ";
-          ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
-          ss << ", however when inspecting the graph, the input type expected was inferred to be ";
-          ss << est_type_opt[i].value() << std::endl;
-          ss << "The compiler is going to use the user setting "
-             << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
-          ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
-          ss << "compatibility with PyTorch's data type convention is required.\n";
-          ss << "If you do indeed see errors at runtime either:\n";
-          ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
-          ss << "- Disable partial compilation by setting require_full_compilation to True";
-          auto warn_str = ss.str();
-          LOG_WARNING(warn_str);
-          // Overwrite type map with user settings
-          first_use_type_map[in][i] = {
-              util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
-        }
+      } else if (
+          util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype) !=
+          est_type_opt[i].value()) {
+        std::stringstream ss;
+        ss << "For input " << in->debugName() << ", found user specified input dtype as ";
+        ss << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+        ss << ", however when inspecting the graph, the input type expected was inferred to be ";
+        ss << est_type_opt[i].value() << std::endl;
+        ss << "The compiler is going to use the user setting "
+           << cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype;
+        ss << "\nThis conflict may cause an error at runtime due to partial compilation being enabled and therefore\n";
+        ss << "compatibility with PyTorch's data type convention is required.\n";
+        ss << "If you do indeed see errors at runtime either:\n";
+        ss << "- Remove the dtype spec for " << in->debugName() << std::endl;
+        ss << "- Disable partial compilation by setting require_full_compilation to True";
+        auto warn_str = ss.str();
+        LOG_WARNING(warn_str);
+        // Overwrite type map with user settings
+        first_use_type_map[in][i] = {
+            util::TRTDataTypeToScalarType(cfg.convert_info.collection_input_spec_map.find(in)->second[i].dtype)};
       }
     } else {
      // The user defined the type so no changes are necessary
     }
+
+    // Insert entry for Value pointer and determined ScalarType
+    inferred_dtypes.insert({in, c10::optional<c10::ScalarType>(util::TRTDataTypeToScalarType(spec[i].dtype))});
    }
  }
-  // }
+  return inferred_dtypes;
 }
 
 std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod, std::string method_name, CompileSpec cfg) {
@@ -307,7 +310,14 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   // Infer the type of an input from the weights of the calculation
   auto first_use_types = ir::get_block_first_calc_dtypes_opt_collection(g->block());
 
-  MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+  // Extract map of IValue to DType
+  auto type_map = MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
+
+  // Use dtype map to autocast inputs to the correct type
+  if (cfg.partitioning_info.enabled && cfg.partitioning_info.truncate_long_and_double) {
+    lowering::AutocastInputs(g, type_map, cfg.lower_info.getGPUDeviceString());
+  }
+
   auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
   auto outputIsCollection = conversion::OutputIsCollection(g->block());
   if (cfg.partitioning_info.enabled &&
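Note: the definition of `ir::TypeMap` does not appear in this diff. From its usage here and in the new test below, it is presumably an alias of roughly this shape (an inference, not taken from the commit; the header path is assumed):

#include <unordered_map>
#include <torch/csrc/jit/ir/ir.h> // for torch::jit::Value

// Assumed shape of torch_tensorrt::core::ir::TypeMap, inferred from usage:
// each graph input Value maps to an optional determined scalar type.
using TypeMap = std::unordered_map<const torch::jit::Value*, c10::optional<c10::ScalarType>>;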

core/lowering/lowering.cpp

+50 −0
@@ -26,6 +26,56 @@ void LowerBlock(torch::jit::Block* b) {
   DropUnusedNodes(b);
 }
 
+void AutocastInputs(std::shared_ptr<torch::jit::Graph>& g, ir::TypeMap input_type_map, std::string target_device_name) {
+  // For each graph input, determine if it can be autocasted
+  for (int i = 0; i < g->inputs().size(); i++) {
+    auto input = g->inputs()[i];
+
+    // Autocasted inputs must be Tensor-type
+    if (input->type()->isSubtypeOf(c10::TensorType::get())) {
+      auto dtype_input = input_type_map.find(input);
+
+      // Ensure the data type to be casted to exists in the type map
+      if (dtype_input == input_type_map.end() || !dtype_input->second) {
+        LOG_DEBUG("No inferred input dtype for tensor " << input->debugName() << ", skipping autocast");
+        continue;
+      }
+
+      auto dtype = dtype_input->second.value();
+      // Currently, we do not autocast inputs for which the determined type is not int or float
+      if (!(dtype == at::kFloat || dtype == at::kInt)) {
+        continue;
+      }
+
+      LOG_DEBUG("Inserting aten::to casting " << input->debugName() << " to dtype " << dtype);
+
+      // Generate cast node sending input tensors to the inferred or specified datatype (float or int)
+      auto const_type = g->insertConstant(dtype);
+      auto const_false = g->insertConstant(0);
+      const_false->setType(torch::jit::BoolType::get());
+      auto cuda = g->insertConstant(target_device_name);
+      cuda->setType(torch::jit::DeviceObjType::get());
+      auto none_val = g->insertNode(g->createNone())->output();
+      auto cast_node = g->create(torch::jit::aten::to, {input, cuda, const_type, const_false, const_false, none_val});
+
+      // Replace all uses of the original tensor with that of the casted tensor
+      g->prependNode(cast_node);
+      input->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
+
+      // Mark the cast node to run in PyTorch for ease of casting
+      LOG_GRAPH("Marking autocast node " << util::node_info(cast_node) << " to run in PyTorch");
+      cast_node->i_(c10::Symbol::attr("to_compile"), (int64_t) false);
+    }
+  }
+
+  LOG_WARNING(
+      "Input tensors to this Torch-TRT engine may have their data types in-place modified "
+      << "if the type does not match the determined required type for the model. To disable this "
+      << "automatic casting, specify truncate_long_and_double=False");
+
+  LOG_GRAPH("Graph after Autocast: " << *g);
+}
+
 void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params, LowerInfo lower_info) {
   torch::jit::EliminateRedundantGuards(g);
   torch::jit::RemoveListMutation(g);
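For readers unfamiliar with the `aten::to` overload used above, its inputs are (self, device, dtype, non_blocking, copy, memory_format), matching the six values passed to `g->create`. A rough eager-mode analogue of the inserted cast follows, as an illustrative sketch rather than code from the commit (the function name is made up and the dtype is fixed to float):

#include <torch/torch.h>

// Illustrative eager-mode equivalent of the inserted graph node: move the
// input to the target device and cast it to the inferred dtype in one call.
at::Tensor autocast_input(const at::Tensor& input) {
  return input.to(torch::Device("cuda"), at::kFloat,
                  /*non_blocking=*/false, /*copy=*/false);
}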

core/lowering/lowering.h

+1 −0
@@ -27,6 +27,7 @@ struct LowerInfo {
 
 void LowerBlock(torch::jit::Block* b);
 void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info);
+void AutocastInputs(std::shared_ptr<torch::jit::Graph>& g, ir::TypeMap input_type_map, std::string target_device_name);
 torch::jit::Module LowerModule(
     const torch::jit::Module& mod,
     std::string method_name,

tests/core/partitioning/test_type_auto_conversion.cpp

+59 −0
@@ -1,4 +1,6 @@
 #include <string>
+#include "core/ir/ir.h"
+#include "core/lowering/lowering.h"
 #include "core/partitioning/partitioning.h"
 #include "core/util/trt_util.h"
 #include "gtest/gtest.h"
@@ -107,3 +109,60 @@ TEST(Partitioning, ImplicitAutoConversionCorrectly) {
   }
   ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[1], 2));
 }
+
+TEST(Partitioning, AutoCastingInputIntsFloatsCorrectly) {
+  const auto graph = R"IR(
+        graph(%x.1 : Tensor,
+              %y.1 : Tensor):
+          %k.1 : int = prim::Constant[value=1]() # examples/custom_converters/toy_model.py:38:12
+          %3 : int = prim::Constant[value=2]() # examples/custom_converters/toy_model.py:40:13
+          %x.5 : Tensor = aten::add_(%x.1, %y.1, %k.1) # examples/custom_converters/toy_model.py:39:8
+          %23 : Tensor = aten::mul(%y.1, %3) # <string>:3:9
+          %x.9 : Tensor = aten::add(%x.5, %23, %k.1) # examples/custom_converters/toy_model.py:40:8
+          %x.13 : Tensor = aten::add(%x.9, %k.1, %k.1) # examples/custom_converters/toy_model.py:41:8
+          %x.17 : Tensor = aten::sub(%x.13, %k.1, %k.1) # examples/custom_converters/toy_model.py:42:8
+          %x.21 : Tensor = aten::add(%x.17, %k.1, %k.1) # examples/custom_converters/toy_model.py:43:8
+          %x.25 : Tensor = aten::sub(%x.21, %k.1, %k.1) # examples/custom_converters/toy_model.py:44:8
+          return (%x.25))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get(), true);
+
+  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  partitioning_info.forced_fallback_operators = {"aten::expand"};
+  partitioning_info.truncate_long_and_double = true;
+  std::vector<torch_tensorrt::core::ir::Input> inputs;
+
+  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
+  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
+
+  std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
+  std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
+  inputs_map.insert({g->inputs()[0], {inputs[0]}});
+  input_types.insert({g->inputs()[0], {{at::kFloat}}});
+  inputs_map.insert({g->inputs()[1], {inputs[1]}});
+  input_types.insert({g->inputs()[1], {{at::kInt}}});
+
+  partitioning_info.collection_input_spec_map = inputs_map;
+  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
+  ctx.input_types_map = input_types;
+
+  // Generate map of input Value * to dtype
+  torch_tensorrt::core::partitioning::populateInputIValues(&ctx);
+  torch_tensorrt::core::ir::TypeMap dtype_map;
+  dtype_map.insert({g->inputs()[0], c10::optional<c10::ScalarType>(at::kFloat)});
+  dtype_map.insert({g->inputs()[1], c10::optional<c10::ScalarType>(at::kInt)});
+
+  torch_tensorrt::core::lowering::AutocastInputs(g, dtype_map, "cuda");
+  torch_tensorrt::core::partitioning::partition(&ctx);
+  auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
+
+  for (auto& seg_block : segmented_blocks) {
+    LOG_DEBUG(seg_block << " cur seg block");
+  }
+
+  // Ensure the first segmented block is a Torch block containing 2 casts
+  ASSERT_TRUE(segmented_blocks[0].target() == torch_tensorrt::core::partitioning::SegmentedBlock::kTorch);
+  ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[0], 2));
+}
