Skip to content

Commit cc10876

Browse files
committed
fix: Fix a core partitioning algo bug where non-tensor input segments are not updated correctly
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 1aa492f commit cc10876

File tree

3 files changed

+35
-38
lines changed

3 files changed

+35
-38
lines changed

Diff for: core/partitioning/partitioning.cpp

+7-8
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg
115115
pytorch_nodes.push_back(n);
116116
prev_non_tensor_outputs = containNonTensorOutputs(n);
117117
} else {
118-
// If pytorch_nodes is not empty, the previous nodes were all tensorrt_nodes. Construct a
118+
// If pytorch_nodes is not empty, the previous nodes were all pytorch_nodes. Construct a
119119
// Pytorch segmented_block and clear the pytorch_nodes list to be later used for new Pytorch segments.
120120
if (!pytorch_nodes.empty()) {
121121
new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
@@ -132,6 +132,7 @@ std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg
132132
new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
133133
}
134134
}
135+
135136
return std::move(new_seg_blocks);
136137
}
137138

@@ -159,6 +160,7 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shar
159160
}
160161
}
161162

163+
// For each non-tensor value in the usage_counts map, keep updating the produce_id to the earliest segmented block that has/produces it.
162164
for (auto& use : usage_counts) {
163165
// Set the produce_id to the segmented block index that contains/produces this non-tensor torch::jit::Value
164166
if (segmented_blocks[i].contain_raw_value(use.first)) {
@@ -167,6 +169,7 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shar
167169
}
168170
}
169171

172+
170173
std::unordered_set<int> updated_segments;
171174
for (auto& use : usage_counts) {
172175
auto use_info = use.second;
@@ -178,9 +181,8 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks) { // , std::shar
178181
// Segmented Blocks with non-tensor inputs will have to be re-segmented as
179182
// TRTorch doesn't support non-tensor inputs for a module.
180183
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]);
181-
segmented_blocks.erase(segmented_blocks.begin() + first_torch_id);
182-
segmented_blocks.insert(
183-
segmented_blocks.begin() + first_torch_id, to_inject_blocks.begin(), to_inject_blocks.end());
184+
auto next_iter = segmented_blocks_list.erase(idx_to_iter[first_torch_id]);
185+
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
184186
updated_segments.insert(first_torch_id);
185187
}
186188
}
@@ -314,6 +316,7 @@ std::vector<SegmentedBlock> segment_graph(torch::jit::Block* block, const Partit
314316
segmented_blocks.emplace_back(SegmentedBlock::kTorch, std::vector<torch::jit::Node*>{n});
315317
continue;
316318
} else if (n->kind() == torch::jit::prim::Loop) {
319+
317320
if (!pytorch_nodes.empty()) {
318321
segmented_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
319322
pytorch_nodes.clear();
@@ -347,19 +350,15 @@ std::vector<SegmentedBlock> Partition(
347350
const PartitionInfo& partition_info) {
348351
LOG_DEBUG(partition_info);
349352
// segment lowering global graph into blocks
350-
LOG_DEBUG("Partitioning graph into PyTorch and TensorRT segmented blocks");
351353
std::vector<SegmentedBlock> segmented_blocks = segment_graph(block, partition_info);
352354

353355
// resolve nonTensor inputs/outputs
354-
LOG_DEBUG("Resolving non-tensor type inputs/outputs (eg: int/float types)");
355356
resolveNonTensorInputs(segmented_blocks);
356357

357358
// register input/output torch::jit::Value for segmented graphs
358-
LOG_DEBUG("Registering input/outputs for segmented blocks");
359359
registerSegmentsOutputs(segmented_blocks, block);
360360

361361
// run shape analysis on each segmented block
362-
LOG_DEBUG("Running shape analysis for all the segmented blocks");
363362
runShapeAnalysis(segmented_blocks, input_ivalues_map);
364363

365364
return segmented_blocks;

Diff for: core/partitioning/shape_analysis.cpp

+1-3
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ void getSegmentsOutputByRunning(
5656
for (auto& input : seg_block.raw_inputs()) {
5757
TRTORCH_CHECK(
5858
ivalues_maps.count(input),
59-
"Could not find torch::jit::Value* " << input->debugName() << " in lowering graph for mini graph input.\n");
59+
"Could not find torch::jit::Value* " << input->debugName() << " produced from " << util::node_info(input->node()) << " in lowering graph for mini graph input.\n");
6060
if (input->node()->kind() == torch::jit::prim::Param) {
6161
jit_inputs_ivalues.push_back(ivalues_maps[input]);
6262
} else if (input->type()->isSubtypeOf(torch::jit::TensorType::get())) {
@@ -108,10 +108,8 @@ void runShapeAnalysis(
108108
std::unordered_map<torch::jit::Value*, torch::jit::IValue>& ivalues_maps) {
109109
// register every segment's input shape, and its running output IValues
110110
for (auto& seg_block : segmented_blocks) {
111-
LOG_DEBUG("Segmented graph: " << *seg_block.g());
112111
torch::jit::ConstantPooling(seg_block.g());
113112
getSegmentsOutputByRunning(seg_block, ivalues_maps);
114-
LOG_DEBUG("=================");
115113
}
116114
return;
117115
}

Diff for: tests/core/partitioning/test_loop_fallback.cpp

+27-27
Original file line numberDiff line numberDiff line change
@@ -33,30 +33,30 @@ TEST(Partitioning, CheckLoopFallbackEvalCompilesCorrectly) {
3333
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
3434
}
3535

36-
// TEST(Partitioning, CheckLoopFallbackNoEvalCompilesCorrectly) {
37-
// torch::jit::script::Module mod;
38-
// try {
39-
// mod = torch::jit::load("tests/modules/loop_fallback_no_eval_scripted.jit.pt");
40-
// } catch (const c10::Error& e) {
41-
// std::cerr << "error loading the model\n";
42-
// return;
43-
// }
44-
//
45-
// const std::vector<std::vector<int64_t>> input_shapes = {{1, 10}};
46-
// std::vector<torch::jit::IValue> jit_inputs_ivalues;
47-
// std::vector<torch::jit::IValue> trt_inputs_ivalues;
48-
// for (auto in_shape : input_shapes) {
49-
// auto in = at::randint(5, in_shape, {at::kCUDA});
50-
// jit_inputs_ivalues.push_back(in.clone());
51-
// trt_inputs_ivalues.push_back(in.clone());
52-
// }
53-
//
54-
// std::vector<trtorch::core::ir::Input> input_ranges{trtorch::core::ir::Input({1, 10})};
55-
// trtorch::core::CompileSpec cfg(input_ranges);
56-
// cfg.partition_info.enabled = true;
57-
//
58-
// auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
59-
// auto trt_mod = trtorch::core::CompileGraph(mod, cfg);
60-
// auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
61-
// ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
62-
// }
36+
TEST(Partitioning, CheckLoopFallbackNoEvalCompilesCorrectly) {
37+
torch::jit::script::Module mod;
38+
try {
39+
mod = torch::jit::load("tests/modules/loop_fallback_no_eval_scripted.jit.pt");
40+
} catch (const c10::Error& e) {
41+
std::cerr << "error loading the model\n";
42+
return;
43+
}
44+
45+
const std::vector<std::vector<int64_t>> input_shapes = {{1, 10}};
46+
std::vector<torch::jit::IValue> jit_inputs_ivalues;
47+
std::vector<torch::jit::IValue> trt_inputs_ivalues;
48+
for (auto in_shape : input_shapes) {
49+
auto in = at::randint(5, in_shape, {at::kCUDA});
50+
jit_inputs_ivalues.push_back(in.clone());
51+
trt_inputs_ivalues.push_back(in.clone());
52+
}
53+
54+
std::vector<trtorch::core::ir::Input> input_ranges{trtorch::core::ir::Input({1, 10})};
55+
trtorch::core::CompileSpec cfg(input_ranges);
56+
cfg.partition_info.enabled = true;
57+
58+
auto jit_results = mod.forward(jit_inputs_ivalues).toTensor();
59+
auto trt_mod = trtorch::core::CompileGraph(mod, cfg);
60+
auto trt_results = trt_mod.forward(trt_inputs_ivalues).toTensor();
61+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results, trt_results, 2e-6));
62+
}

0 commit comments

Comments
 (0)