Commit 253b3c7

Merge pull request #1225 from pytorch/fix_collection_partitioning
fix: fix the error where a collection input was segmented into a TRT subgraph
2 parents: 9bce034 + 6d0b1d3
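
With collections enabled, a graph's top-level inputs and outputs can be non-Tensor values such as tuples or lists. This change extends the pre-partitioning pass so that any node consuming such a graph input, in addition to any node producing such a graph output, is registered as a fallback node and therefore stays in the PyTorch segment instead of being folded into a TensorRT subgraph.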

File tree

1 file changed: +17 −3


core/partitioning/partitioning.cpp (+17 −3)
@@ -90,14 +90,26 @@ std::vector<torch::jit::Node*> getDependencyNodes(
   return stk;
 }
 
-void find_nontensor_output_nodes(
+// check whether the graph's inputs and outputs are Tensors once collections are enabled; if one is not, fall back
+// the related nodes
+void fallback_graph_nontensor_in_out(
     torch::jit::Block* block,
     std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
+  // fall back the nodes that produce the entire graph's nonTensor outputs
   for (auto i : block->outputs()) {
     if (!isTensor(i)) {
       global_fallback_nodes.insert({i->node(), FallbackNodeType::kNON_TENSOR});
     }
   }
+
+  // fall back the nodes that consume the entire graph's nonTensor inputs
+  for (auto i : block->inputs()) {
+    if (!isTensor(i)) {
+      for (auto use : i->uses()) {
+        global_fallback_nodes.insert({use.user, FallbackNodeType::kNON_TENSOR});
+      }
+    }
+  }
 }
 
 void find_all_fallback_nodes(
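
To make the effect of the new pass concrete, below is a minimal standalone C++ sketch of the same marking logic. SimpleValue, SimpleNode, SimpleBlock, and the kNON_TENSOR enumerator value are simplified stand-ins invented for illustration, not the torch::jit or Torch-TensorRT API; the real pass operates on torch::jit::Block exactly as shown in the diff above.

#include <iostream>
#include <string>
#include <unordered_map>
#include <vector>

// Simplified stand-ins for the torch::jit graph types (hypothetical, for illustration).
struct SimpleNode {
  std::string name;
};

struct SimpleValue {
  bool is_tensor;
  SimpleNode* producer;            // plays the role of Value::node()
  std::vector<SimpleNode*> users;  // plays the role of Value::uses()
};

struct SimpleBlock {
  std::vector<SimpleValue*> inputs;
  std::vector<SimpleValue*> outputs;
};

enum FallbackNodeType { kNON_TENSOR = 2 };  // enumerator value chosen arbitrarily here

// Mirrors fallback_graph_nontensor_in_out: mark producers of nonTensor graph
// outputs and consumers of nonTensor graph inputs so they stay in PyTorch.
void fallback_graph_nontensor_in_out(
    SimpleBlock* block,
    std::unordered_map<SimpleNode*, int>& global_fallback_nodes) {
  for (auto* o : block->outputs) {
    if (!o->is_tensor) {
      global_fallback_nodes.insert({o->producer, kNON_TENSOR});
    }
  }
  for (auto* i : block->inputs) {
    if (!i->is_tensor) {
      for (auto* user : i->users) {
        global_fallback_nodes.insert({user, kNON_TENSOR});
      }
    }
  }
}

int main() {
  // A graph whose only input is a tuple (a collection, hence not a Tensor),
  // consumed by an unpack node, and whose output is a Tensor from a conv node.
  SimpleNode unpack{"prim::TupleUnpack"};
  SimpleNode conv{"aten::conv2d"};
  SimpleValue tuple_in{/*is_tensor=*/false, nullptr, {&unpack}};
  SimpleValue tensor_out{/*is_tensor=*/true, &conv, {}};
  SimpleBlock block{{&tuple_in}, {&tensor_out}};

  std::unordered_map<SimpleNode*, int> fallback;
  fallback_graph_nontensor_in_out(&block, fallback);

  // Only the tuple-consuming node is marked; the conv node remains eligible for TensorRT.
  for (const auto& kv : fallback) {
    std::cout << kv.first->name << " -> fallback\n";
  }
  return 0;
}

Running the sketch prints only prim::TupleUnpack, matching the intent of the fix: the node touching the collection input falls back to PyTorch while the Tensor-only work can still be placed in a TensorRT block.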
@@ -202,6 +214,7 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Blo
       }
     }
   }
+
   std::for_each(segmented_blocks.begin(), segmented_blocks.end(), [](SegmentedBlock& seg_block) {
     torch::jit::EliminateDeadCode(seg_block.g());
   });
@@ -440,8 +453,9 @@ PartitionedGraph Partition(
     const PartitionInfo& partition_info,
     std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
   LOG_DEBUG(partition_info);
-  // if there is nonTensor output for the entire graph, fallback the node that produces this nonTensor output
-  find_nontensor_output_nodes(block, global_fallback_nodes);
+  // if there is a nonTensor input/output for the entire graph, fall back the node that consumes/produces this
+  // nonTensor input/output
+  fallback_graph_nontensor_in_out(block, global_fallback_nodes);
 
   // segment lowering global graph into blocks
   LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks");
