Skip to content

Commit 9bce034

Browse files
authored
Merge pull request #1220 from pytorch/fix_collection_partitioning
2 parents 5cff257 + f866dba commit 9bce034

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

core/partitioning/partitioning.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,16 @@ std::vector<torch::jit::Node*> getDependencyNodes(
9090
return stk;
9191
}
9292

93+
void find_nontensor_output_nodes(
94+
torch::jit::Block* block,
95+
std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
96+
for (auto i : block->outputs()) {
97+
if (!isTensor(i)) {
98+
global_fallback_nodes.insert({i->node(), FallbackNodeType::kNON_TENSOR});
99+
}
100+
}
101+
}
102+
93103
void find_all_fallback_nodes(
94104
std::unordered_map<torch::jit::Node*, int>& initial_fallback_nodes,
95105
std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
@@ -430,6 +440,9 @@ PartitionedGraph Partition(
430440
const PartitionInfo& partition_info,
431441
std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
432442
LOG_DEBUG(partition_info);
443+
// if there is nonTensor output for the entire graph, fallback the node that produces this nonTensor output
444+
find_nontensor_output_nodes(block, global_fallback_nodes);
445+
433446
// segment lowering global graph into blocks
434447
LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks");
435448
PartitionedGraph segmented_blocks = segment_graph(block, partition_info, global_fallback_nodes);

0 commit comments

Comments
 (0)