Commit e9e2799

feat: support resegmentation of fallback nodes caused by min_block_size != 1
Signed-off-by: Bo Wang <[email protected]>
1 parent 1a22204 commit e9e2799

2 files changed: +70 -12 lines changed

core/compiler.cpp (+2 -2)
@@ -240,7 +240,7 @@ GraphAndMapping ConstructFallbackGraph(
   }
 
   for (auto& seg_block : segmented_blocks) {
-    LOG_INFO(*seg_block.g() << "(GraphInSegmentedBlock)\n");
+    LOG_INFO(seg_block << "(GraphInSegmentedBlock)\n");
     std::ostringstream trt_engine_id;
     trt_engine_id << reinterpret_cast<const int*>(&seg_block);
 
@@ -436,7 +436,7 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
   auto graph_and_mapping =
       ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params, fallback_nodes);
   new_g = graph_and_mapping.first;
-  LOG_INFO("Segmented Graph: " << *new_g);
+  LOG_INFO("Graph after Fallback: " << *new_g);
 
   // if there is no tensorrt engine self in fallback graph, there is no conversion, we just return the initial
   // module
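
Note: the new LOG_INFO(seg_block << ...) call only compiles if SegmentedBlock is streamable, presumably via an operator<< defined alongside SegmentedBlock in the partitioning sources. A self-contained toy (all names here are stand-ins, not the real Torch-TensorRT types) showing the shape of such an overload and why the logging call can now take the block directly:

#include <iostream>
#include <string>

// Toy stand-in for SegmentedBlock: enough structure to show a streaming
// operator<< that prints the block's target together with its subgraph.
struct SegmentedBlock {
  enum Target { kTorch, kTensorRT };
  Target target;
  std::string graph_ir; // stands in for the block's torch::jit::Graph

  friend std::ostream& operator<<(std::ostream& os, const SegmentedBlock& b) {
    os << "Target: " << (b.target == kTensorRT ? "TensorRT" : "Torch") << "\n"
       << b.graph_ir;
    return os;
  }
};

int main() {
  SegmentedBlock b{SegmentedBlock::kTensorRT, "graph(%x : Tensor):\n  return (%x)\n"};
  // Mirrors LOG_INFO(seg_block << "(GraphInSegmentedBlock)\n") from the diff.
  std::cout << b << "(GraphInSegmentedBlock)\n";
}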

core/partitioning/partitioning.cpp (+68 -10)
@@ -98,9 +98,13 @@ std::vector<torch::jit::Node*> getDependencyNodes(
   return stk;
 }
 
-void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+void find_all_fallback_nodes(
+    std::unordered_map<torch::jit::Node*, int>& initial_fallback_nodes,
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
+  // initial_fallback_nodes are the fallback nodes that we have before we run BFS in this function
+  // global_fallback_nodes are the fallback nodes that we maintain globally
   std::queue<torch::jit::Node*> q;
-  for (auto& node : fallback_nodes) {
+  for (auto& node : initial_fallback_nodes) {
     q.push(node.first);
   }
 
@@ -111,7 +115,7 @@ void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallbac
     // for every node that produces this fallback node's NonTensor input, they should fallback too
     for (auto input : cur_node->inputs()) {
       if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
-          fallback_nodes.insert({input->node(), 4}).second) {
+          global_fallback_nodes.insert({input->node(), 4}).second) {
         q.push(input->node());
       }
     }
@@ -120,7 +124,7 @@ void find_all_fallback_nodes(std::unordered_map<torch::jit::Node*, int>& fallbac
       if (!isTensor(output)) {
         for (auto use : output->uses()) {
           auto node = use.user;
-          if (node->kind() != torch::jit::prim::Constant && fallback_nodes.insert({node, 4}).second) {
+          if (node->kind() != torch::jit::prim::Constant && global_fallback_nodes.insert({node, 4}).second) {
            q.push(node);
          }
        }
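
The two-map split is the point of this change: only the newly seeded nodes (initial_fallback_nodes) are enqueued, while every discovery is recorded in the global map, so repeated invocations don't re-walk nodes that already fell back. A self-contained toy (hypothetical Node type, not the TorchScript IR) with the same BFS shape:

#include <iostream>
#include <queue>
#include <unordered_map>
#include <vector>

// Toy IR node: lists the producers of its non-tensor inputs and the
// consumers of its non-tensor outputs, mirroring the edges the real pass walks.
struct Node {
  int id;
  std::vector<Node*> nontensor_producers;
  std::vector<Node*> nontensor_consumers;
};

// Same shape as find_all_fallback_nodes above: seed the queue from the newly
// marked nodes, record discoveries in the global map with reason code 4, and
// only enqueue nodes seen for the first time.
void propagate_fallback(
    std::unordered_map<Node*, int>& initial_fallback_nodes,
    std::unordered_map<Node*, int>& global_fallback_nodes) {
  std::queue<Node*> q;
  for (auto& kv : initial_fallback_nodes) {
    q.push(kv.first);
  }
  while (!q.empty()) {
    Node* cur = q.front();
    q.pop();
    for (Node* p : cur->nontensor_producers) {
      if (global_fallback_nodes.insert({p, 4}).second) q.push(p);
    }
    for (Node* c : cur->nontensor_consumers) {
      if (global_fallback_nodes.insert({c, 4}).second) q.push(c);
    }
  }
}

int main() {
  Node a{0, {}, {}}, b{1, {}, {}}, c{2, {}, {}};
  b.nontensor_producers = {&a}; // a produces a non-tensor value that b consumes
  b.nontensor_consumers = {&c}; // c consumes a non-tensor value that b produces
  std::unordered_map<Node*, int> seed{{&b, 3}};
  std::unordered_map<Node*, int> global_map(seed);
  propagate_fallback(seed, global_map);
  std::cout << "fallback nodes: " << global_map.size() << "\n"; // prints 3
}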
@@ -231,6 +235,8 @@ bool check_node_fallback(torch::jit::Node* n, const std::unordered_map<torch::ji
     LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n));
   } else if (fallback_nodes.at(n) == 2) {
     LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n));
+  } else if (fallback_nodes.at(n) == 3) {
+    LOG_GRAPH("Node fallback to Torch because of min_block_size: " << util::node_info(n));
   } else {
     LOG_GRAPH(
         "Node fallback to Torch because the NonTensor dependencies with other fallback nodes: "
@@ -284,22 +290,74 @@ void get_fallback_nodes(
   return;
 }
 
+std::vector<torch::jit::Node*> traverse_nodes_for_min_block_size(
+    torch::jit::Block* block,
+    const std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes,
+    size_t min_block_size) {
+  auto nodes = block->nodes();
+  std::vector<torch::jit::Node*> cur_trt_nodes;
+  std::vector<torch::jit::Node*> min_block_fallback_nodes;
+  for (const auto n : nodes) {
+    if (n->kind() == torch::jit::prim::Constant)
+      continue;
+
+    // check whether the current node falls back or not
+    if (!global_fallback_nodes.count(n)) {
+      // if this node is not in the fallback nodes, then it's in a trt segment
+      cur_trt_nodes.push_back(n);
+    } else {
+      if (cur_trt_nodes.size() < min_block_size) {
+        min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+      }
+      cur_trt_nodes.clear();
+    }
+  }
+  if (cur_trt_nodes.size() < min_block_size) {
+    min_block_fallback_nodes.insert(min_block_fallback_nodes.end(), cur_trt_nodes.begin(), cur_trt_nodes.end());
+  }
+  return min_block_fallback_nodes;
+}
+
+void find_min_block_size_fallback_nodes(
+    torch::jit::Block* block,
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes,
+    size_t min_block_size) {
+  // first traverse all the nodes to find the initial nodes that don't meet the min_block_size requirement
+  auto min_block_fallback_nodes = traverse_nodes_for_min_block_size(block, global_fallback_nodes, min_block_size);
+  std::unordered_map<torch::jit::Node*, int> initial_fallback_nodes;
+
+  // keep falling back until all segments meet the min_block_size requirement
+  while (!min_block_fallback_nodes.empty()) {
+    for (const auto i : min_block_fallback_nodes) {
+      initial_fallback_nodes.insert({i, 3});
+    }
+    global_fallback_nodes.insert(initial_fallback_nodes.begin(), initial_fallback_nodes.end());
+    // find the nodes that fall back because they depend on min_block_size-caused fallback nodes
+    find_all_fallback_nodes(initial_fallback_nodes, global_fallback_nodes);
+    // keep traversing the graph until no node falls back because of min_block_size
+    min_block_fallback_nodes = traverse_nodes_for_min_block_size(block, global_fallback_nodes, min_block_size);
+  }
+}
+
 PartitionedGraph segment_graph(
     torch::jit::Block* block,
     const PartitionInfo& partition_info,
-    std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
   auto min_block_size = partition_info.min_block_size;
   std::unordered_set<std::string> forced_fallback_ops(
       partition_info.forced_fallback_operators.begin(), partition_info.forced_fallback_operators.end());
 
   // get the initial fallback nodes (nodes that are unsupported or forced fallback)
-  get_fallback_nodes(block, forced_fallback_ops, fallback_nodes);
+  get_fallback_nodes(block, forced_fallback_ops, global_fallback_nodes);
 
   // For fallback nodes, if it consumes any NonTensor inputs or TensorList inputs, then the node that produces this
   // input should also fallback Similarly, if it produces any NonTensor outputs or TensorList outputs, then the node
   // that produces this input should also fallback
   // TODO: don't need to fallback the TensorList related nodes once the collection feature is supported
-  find_all_fallback_nodes(fallback_nodes);
+  find_all_fallback_nodes(global_fallback_nodes, global_fallback_nodes);
+
+  // find all fallback nodes because of the min_block_size requirement
+  find_min_block_size_fallback_nodes(block, global_fallback_nodes, min_block_size);
 
   auto nodes = block->nodes();
 
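The fixed point is the subtle part: falling back a short TRT segment can shrink or split a neighboring segment (directly, or via the dependency BFS above), which may itself drop below min_block_size, so the scan repeats until a pass finds nothing new. A self-contained simulation over a linear sequence of nodes (toy types, not the TorchScript IR) showing the loop converging:

#include <cstddef>
#include <iostream>
#include <set>
#include <vector>

// Scan a linear node sequence (node = index) and collect TRT runs shorter
// than min_block_size, mirroring traverse_nodes_for_min_block_size.
std::vector<int> short_trt_runs(size_t n, const std::set<int>& fallback, size_t min_block_size) {
  std::vector<int> cur, result;
  auto flush = [&]() {
    if (cur.size() < min_block_size) result.insert(result.end(), cur.begin(), cur.end());
    cur.clear();
  };
  for (int i = 0; i < static_cast<int>(n); ++i) {
    if (!fallback.count(i)) cur.push_back(i);
    else flush();
  }
  flush();
  return result;
}

int main() {
  // Nodes 0..6; nodes 2 and 5 already fall back (unsupported ops, say).
  size_t n = 7, min_block_size = 3;
  std::set<int> fallback{2, 5};
  // Mirror find_min_block_size_fallback_nodes: repeat until no TRT run is
  // shorter than min_block_size. (The real pass also runs the dependency
  // BFS between iterations; this toy has no non-tensor edges.)
  while (true) {
    auto shorties = short_trt_runs(n, fallback, min_block_size);
    if (shorties.empty()) break;
    fallback.insert(shorties.begin(), shorties.end());
  }
  std::cout << "fallback count after convergence: " << fallback.size() << "\n"; // prints 7
}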
@@ -313,7 +371,7 @@ PartitionedGraph segment_graph(
       continue;
     }
 
-    if (check_node_fallback(n, fallback_nodes)) {
+    if (check_node_fallback(n, global_fallback_nodes)) {
       in_prog_trt_blk_nodes.push_back(n);
 
       // If there is an active PyTorch block and we have passed the threshold for a valid TRT
@@ -379,11 +437,11 @@ PartitionedGraph Partition(
     torch::jit::Block* block,
     std::unordered_map<const torch::jit::Value*, torch::jit::IValue>& example_tensor_map,
     const PartitionInfo& partition_info,
-    std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
+    std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
   LOG_DEBUG(partition_info);
   // segment lowering global graph into blocks
   LOG_DEBUG("Parititioning source module into PyTorch and TensorRT sub blocks");
-  PartitionedGraph segmented_blocks = segment_graph(block, partition_info, fallback_nodes);
+  PartitionedGraph segmented_blocks = segment_graph(block, partition_info, global_fallback_nodes);
 
   // It's possible that some TensorRT blocks have nonTensor inputs/output because they are interleaved by Torch blocks
 
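From the user side, the setting that drives this pass is min_block_size on the compile spec. A hedged usage sketch; the field and function names follow the public torch_tensorrt C++ API as I understand it around this release, so treat them as assumptions rather than a verified snippet:

#include <torch/script.h>
#include "torch_tensorrt/torch_tensorrt.h"

int main() {
  auto mod = torch::jit::load("model.ts");
  auto in = torch_tensorrt::Input(std::vector<int64_t>{1, 3, 224, 224});
  torch_tensorrt::ts::CompileSpec spec({in});
  // TRT segments shorter than 3 nodes now fall back to Torch, and this
  // commit re-runs segmentation until every remaining segment qualifies.
  spec.min_block_size = 3;
  spec.require_full_compilation = false;
  auto trt_mod = torch_tensorrt::ts::compile(mod, spec);
}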