Commit 86ff042
feat: support int64 <=> int32 auto conversion (#1407)
* feat: support int64 <=> int32 auto conversion
* fix: cover more cases
* test: add test cases for automatic int64 <=> int32 type conversion
* fix: fix wrong input/output indexing bug
* chore: fix typo and apply linting

Signed-off-by: Bo Wang <[email protected]>
1 parent 19e536a commit 86ff042
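To make the intent of the feature concrete, here is a minimal standalone sketch (an illustration only, not code from this commit; it assumes a plain libtorch build) of the dtype round-trip that this change automates at a TensorRT/Torch segment boundary when truncate_long_and_double is enabled:

// boundary_cast_sketch.cpp -- hypothetical illustration, not code from this commit.
#include <torch/torch.h>
#include <iostream>

int main() {
  // TensorRT cannot produce int64, so a TRT segment hands over int32 values.
  at::Tensor trt_out = torch::arange(4, torch::kInt);
  // Inserted input cast for a Torch fallback block: Int32 -> Int64.
  at::Tensor torch_in = trt_out.to(at::kLong);
  // The fallback block then runs with the dtype the original graph expected.
  at::Tensor torch_out = torch_in + 1;
  // Inserted output cast before the value re-enters TensorRT: Int64 -> Int32.
  at::Tensor back = torch_out.to(at::kInt);
  std::cout << back.dtype() << std::endl;  // prints: int
}

The aten::to nodes that the diff below inserts compute exactly these two casts, one at Torch-block inputs and one at Torch-block outputs.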

File tree: 4 files changed, +200 -3 lines changed


core/partitioning/partitioning.cpp (+1)

@@ -553,6 +553,7 @@ void partition(PartitioningCtx* ctx, ExampleIValues& example_tensor_map) {
     registerSegmentsOutputs(ctx, block);
 
     // run shape analysis on each segmented block
+    LOG_DEBUG("Running shape analysis for segmented graphs");
     runShapeAnalysis(ctx, block, example_tensor_map);
   }
 }

core/partitioning/shape_analysis.cpp (+88 -3)

@@ -1,3 +1,4 @@
+#include <queue>
 #include "ATen/ATen.h"
 #include "torch/csrc/jit/api/module.h"
 #include "torch/csrc/jit/passes/constant_pooling.h"
@@ -65,6 +66,61 @@ std::unordered_map<const torch::jit::Value*, torch::jit::IValue> generateRandomI
   return ivalue_map;
 }
 
+torch::jit::Node* getUpstreamCastNode(torch::jit::Value* val) {
+  std::queue<torch::jit::Value*> q;
+  q.push(val);
+  std::unordered_set<torch::jit::Node*> visited;
+  while (!q.empty()) {
+    auto cur_val = q.front();
+    q.pop();
+    auto node = cur_val->node();
+    if ((node->kind().toQualString() == std::string("aten::to")) &&
+        ((node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::IntType) ||
+         (node->inputs()[2]->node()->output()->type()->kind() == torch::jit::TypeKind::IntType))) {
+      return node;
+    }
+    if (node->kind() != torch::jit::prim::Constant && !visited.count(node)) {
+      visited.insert(node);
+      for (auto input : node->inputs()) {
+        q.push(input);
+      }
+    }
+  }
+  return nullptr;
+}
+
+torch::jit::Node* createCastNode(SegmentedBlock& seg_block, size_t index, bool is_input) {
+  auto cast_raw_value = is_input ? seg_block.raw_inputs()[index] : seg_block.raw_outputs()[index];
+  auto cast_subgraph_value = is_input ? seg_block.inputs()[index] : seg_block.outputs()[index];
+  torch::jit::Node* cast_node = getUpstreamCastNode(cast_raw_value);
+  auto g = seg_block.g();
+  // if we can find an upstream aten::to node, we use its parameters to create the new cast node
+  if (cast_node) {
+    std::unordered_map<torch::jit::Value*, torch::jit::Value*> value_map;
+    value_map.insert({cast_node->inputs()[0], cast_subgraph_value});
+    if (!is_input) {
+      // if this value is an output, we need to cast it to int32
+      auto const_val = g->insertConstant(3);
+      if (cast_node->inputs()[1]->node()->output()->type()->kind() == torch::jit::TypeKind::DeviceObjType) {
+        value_map.insert({cast_node->inputs()[2], const_val});
+      } else {
+        value_map.insert({cast_node->inputs()[1], const_val});
+      }
+    }
+    auto env = [&](torch::jit::Value* v) { return util::getOrAddInputForValue(v, g, value_map); };
+    cast_node = g->createClone(cast_node, env);
+  } else {
+    // if there is no explicit aten::to cast operation, we need to create a node
+    auto const_type = is_input ? g->insertConstant(4) : g->insertConstant(3);
+    auto const_zero = g->insertConstant(0);
+    const_zero->setType(torch::jit::BoolType::get());
+    auto none_val = g->insertNode(g->createNone())->output();
+    cast_node = g->create(torch::jit::aten::to, {cast_subgraph_value, const_type, const_zero, const_zero, none_val});
+  }
+  return cast_node;
+}
+
 void getSegmentsOutputByRunning(
     SegmentedBlock& seg_block,
     std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& ivalues_maps,
@@ -150,16 +206,45 @@ void getSegmentsOutputByRunning(
     ivalues_maps[output] = jit_results[idx++];
   }
 
+  // auto int64 <=> int32 conversion
+  if (seg_block.target() == SegmentedBlock::kTorch && partitioning_info.truncate_long_and_double) {
+    // First, check if there is an Int64 input
+    for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
+      if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
+        auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
+        at::ScalarType t = cur_ivalue.toTensor().scalar_type();
+        if (t == at::kLong) {
+          // we add a cast operation to cast the type back to Int64
+          auto cast_node = createCastNode(seg_block, i, true);
+          seg_block.g()->prependNode(cast_node);
+          seg_block.inputs()[i]->replaceAllUsesAfterNodeWith(cast_node, cast_node->outputs()[0]);
+        }
+      }
+    }
+    for (size_t i = 0; i < seg_block.outputs().size(); ++i) {
+      if (ivalues_maps[seg_block.raw_outputs()[i]].isTensor()) {
+        auto cur_ivalue = ivalues_maps[seg_block.raw_outputs()[i]];
+        at::ScalarType t = cur_ivalue.toTensor().scalar_type();
+        if (t == at::kLong) {
+          auto cast_node = createCastNode(seg_block, i, false);
+          seg_block.g()->appendNode(cast_node);
+          seg_block.g()->block()->replaceOutput(i, cast_node->outputs()[0]);
+        }
+      }
+    }
+  }
+
   // set input shape for each segmented block so we will use it in the conversion process
   std::vector<ir::Input> input_shapes;
   std::vector<at::ScalarType> input_types;
-  for (auto& i : seg_block.raw_inputs()) {
-    if (ivalues_maps[i].isTensor()) {
+  for (size_t i = 0; i < seg_block.inputs().size(); ++i) {
+    if (ivalues_maps[seg_block.raw_inputs()[i]].isTensor()) {
       // set the input_shape and data_type
       // we can use a temp value here instead of replacing the values in ivalues_map since we only use ivalues_map for
       // shape inference
-      auto cur_ivalue = ivalues_maps[i];
+      auto cur_ivalue = ivalues_maps[seg_block.raw_inputs()[i]];
       at::ScalarType t = cur_ivalue.toTensor().scalar_type();
+
       if (!partitioning_info.truncate_long_and_double && (t == at::kLong || t == at::kDouble)) {
         TORCHTRT_THROW_ERROR(
             "Unable to process subgraph input type of at::kLong/at::kDouble, try to compile model with truncate_long_and_double enabled");

tests/core/partitioning/BUILD (+5)

@@ -40,6 +40,10 @@ partitioning_test(
     name = "test_resolve_nontensor_inputs",
 )
 
+partitioning_test(
+    name = "test_type_auto_conversion",
+)
+
 cc_test(
     name = "test_loading_model",
     srcs = ["test_loading_model.cpp"],
@@ -112,5 +116,6 @@ test_suite(
         ":test_shape_analysis",
         ":test_stitched_graph",
         ":test_tensorrt_conversion",
+        ":test_type_auto_conversion",
     ],
 )
tests/core/partitioning/test_type_auto_conversion.cpp (new file, +106)
@@ -0,0 +1,106 @@
+#include <string>
+#include "core/partitioning/partitioning.h"
+#include "core/util/trt_util.h"
+#include "gtest/gtest.h"
+#include "torch/csrc/jit/ir/irparser.h"
+#include "torch/script.h"
+
+bool checkInsertedCastNodeNumber(torch_tensorrt::core::partitioning::SegmentedBlock& seg_block, int target_count) {
+  int64_t cnt = 0;
+  for (auto node : seg_block.nodes()) {
+    if (node->kind().toQualString() == std::string("aten::to")) {
+      cnt++;
+    }
+  }
+  std::cout << "Found " << cnt << " inserted aten::to nodes (looking for " << target_count
+            << " aten::to nodes)" << std::endl;
+
+  return target_count == cnt;
+}
+
+TEST(Partitioning, ExplicitNodeAutoConversionCorrectly) {
+  const auto graph = R"IR(
+        graph(%0 : Tensor,
+              %1 : Tensor):
+          %2 : int = prim::Constant[value=4]()
+          %3 : bool = prim::Constant[value=0]()
+          %4 : NoneType = prim::Constant()
+          %5 : int = prim::Constant[value=1]()
+          %7 : Tensor = aten::to(%1, %2, %3, %3, %4)
+          %8 : Tensor = aten::mul(%0, %0)
+          %9 : Tensor = aten::scatter(%8, %5, %7, %5)
+          %10 : Tensor = aten::scatter(%7, %5, %7, %5)
+          %12 : Tensor = aten::add(%10, %10, %5)
+          return (%9, %12))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get(), true);
+
+  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  partitioning_info.forced_fallback_operators = {"aten::scatter"};
+  partitioning_info.truncate_long_and_double = true;
+  std::vector<torch_tensorrt::core::ir::Input> inputs;
+  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
+  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
+
+  std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
+  std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
+  inputs_map.insert({g->inputs()[0], {inputs[0]}});
+  input_types.insert({g->inputs()[0], {{at::kFloat}}});
+  inputs_map.insert({g->inputs()[1], {inputs[1]}});
+  input_types.insert({g->inputs()[1], {{at::kInt}}});
+
+  auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
+
+  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
+  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
+  auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
+
+  for (auto& seg_block : segmented_blocks) {
+    LOG_DEBUG(seg_block << " cur seg block");
+  }
+  ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[1], 2));
+}
+
+TEST(Partitioning, ImplicitAutoConversionCorrectly) {
+  const auto graph = R"IR(
+        graph(%0 : Tensor):
+          %2 : int = prim::Constant[value=0]()
+          %4 : int = aten::size(%0, %2)
+          %6 : Tensor = prim::NumToTensor(%4)
+          %2 : int = prim::Constant[value=5]()
+          %7 : int[] = prim::ListConstruct(%2, %2)
+          %8 : bool = prim::Constant[value=0]()
+          %9 : Tensor = aten::expand(%6, %7, %8)
+
+          %10 : Tensor = aten::mul(%9, %9)
+          return (%10))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get(), true);
+
+  torch_tensorrt::core::partitioning::PartitioningInfo partitioning_info;
+  partitioning_info.enabled = true;
+  partitioning_info.forced_fallback_operators = {"aten::expand"};
+  partitioning_info.truncate_long_and_double = true;
+  std::vector<torch_tensorrt::core::ir::Input> inputs;
+
+  inputs.push_back(torch_tensorrt::core::ir::Input({5, 5}));
+
+  std::unordered_map<const torch::jit::Value*, std::vector<torch_tensorrt::core::ir::Input>> inputs_map;
+  std::unordered_map<const torch::jit::Value*, std::vector<c10::optional<at::ScalarType>>> input_types;
+  inputs_map.insert({g->inputs()[0], {inputs[0]}});
+  input_types.insert({g->inputs()[0], {{at::kFloat}}});
+
+  auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
+
+  torch_tensorrt::core::partitioning::PartitioningCtx ctx(g->block(), partitioning_info);
+  torch_tensorrt::core::partitioning::partition(&ctx, input_ivalues_map);
+  auto segmented_blocks = ctx.partitioned_blocks.begin()->second;
+
+  for (auto& seg_block : segmented_blocks) {
+    LOG_DEBUG(seg_block << " cur seg block");
+  }
+  ASSERT_TRUE(checkInsertedCastNodeNumber(segmented_blocks[1], 2));
+}
