Commit d479c98

fix: fix the fallback-related issue after merging collection

Signed-off-by: Bo Wang <[email protected]>

1 parent 8385253 commit d479c98

1 file changed: core/partitioning/partitioning.cpp (19 additions, 51 deletions)
@@ -17,22 +17,17 @@ struct usage_info {
   std::vector<size_t> tensorrt_use_id; // ids of segmented blocks which are of type TensorRT
 };
 
-inline bool isTensorOrTensorList(torch::jit::Value* val) {
-  return val->type()->isSubtypeOf(torch::jit::TensorType::get()) ||
-      val->type()->isSubtypeOf(torch::jit::ListType::ofTensors());
-}
-
-inline bool isTensorList(torch::jit::Value* val) {
-  return val->type()->isSubtypeOf(torch::jit::ListType::ofTensors());
-}
-
 inline bool isTensor(torch::jit::Value* val) {
   return val->type()->isSubtypeOf(torch::jit::TensorType::get());
 }
 
+inline bool isListOrTuple(torch::jit::Value* val) {
+  return val->type()->kind() == torch::jit::TypeKind::TupleType || val->type()->kind() == torch::jit::TypeKind::ListType;
+}
+
 bool containNonTensorOutputs(torch::jit::Node* n) {
   for (auto output : n->outputs()) {
-    if (!isTensorOrTensorList(output)) {
+    if (!isTensor(output)) {
       return true;
     }
   }
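
Why the helper changed: the removed isTensorOrTensorList() only recognized Tensor and List[Tensor], so collection values such as tuples or lists with non-Tensor elements fell through, while the new isListOrTuple() keys off the TypeKind instead. Below is a minimal standalone sketch of the difference, assuming a libtorch build; it uses the underlying c10 spellings of the JIT types, and the values in main() are illustrative only.

#include <torch/script.h>

#include <iostream>

int main() {
  auto tensor_ty = c10::TensorType::get();
  auto int_list_ty = c10::ListType::ofInts();
  auto tuple_ty = c10::TupleType::create({tensor_ty, tensor_ty});

  // Old check: "is a subtype of Tensor or List[Tensor]" -- an int list fails.
  std::cout << int_list_ty->isSubtypeOf(c10::ListType::ofTensors()) << "\n"; // 0

  // New check: any List or Tuple TypeKind qualifies, whatever the elements.
  std::cout << (int_list_ty->kind() == c10::TypeKind::ListType) << "\n"; // 1
  std::cout << (tuple_ty->kind() == c10::TypeKind::TupleType) << "\n";   // 1
  return 0;
}
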
@@ -68,6 +63,7 @@ std::vector<torch::jit::Node*> findModifyingNodes(
   return modifying_nodes;
 }
 
+// this function is only used when a TRT segment produces nonTensor values which are used by a later TRT segment
 std::vector<torch::jit::Node*> getDependencyNodes(
     const std::vector<torch::jit::Value*>& vals,
     const SegmentedBlock& seg_block) {
@@ -88,7 +84,7 @@ std::vector<torch::jit::Node*> getDependencyNodes(
       stk.insert(stk.end(), modifying_nodes.rbegin(), modifying_nodes.rend());
       stk.push_back(node);
       for (auto input : node->inputs()) {
-        if (!isTensorOrTensorList(input)) {
+        if (!isTensor(input)) {
           q.push(input);
         }
       }
@@ -113,15 +109,19 @@ void find_all_fallback_nodes(
     auto cur_node = q.front();
     q.pop();
     // for every node that produces this fallback node's NonTensor input, they should fallback too
+    // Even though the collection feature is supported, TRT List/Tuple outputs are not supported yet, so the
+    // nodes that produce a List/Tuple still cannot be in a TRT segment
     for (auto input : cur_node->inputs()) {
       if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
           global_fallback_nodes.insert({input->node(), FallbackNodeType::kNON_TENSOR}).second) {
         q.push(input->node());
       }
     }
     // for every node that consumes this fallback node's NonTensor output, they should fallback too
+    // Since the collection feature is supported, a TRT segment can take List/Tuple inputs, so we only
+    // fall back the nodes whose inputs are not Tensor/List/Tuple
     for (auto output : cur_node->outputs()) {
-      if (!isTensor(output)) {
+      if (!isTensor(output) && !isListOrTuple(output)) {
         for (auto use : output->uses()) {
           auto node = use.user;
           if (node->kind() != torch::jit::prim::Constant && global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
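
The two added comments encode an asymmetric rule: producers of non-Tensor values still fall back (a TRT segment cannot emit List/Tuple outputs yet), while consumers no longer fall back when the value is a List/Tuple (collections are now valid TRT segment inputs). Here is a self-contained toy model of that propagation; the Node/Value structs are hypothetical stand-ins, not the real JIT IR.

#include <iostream>
#include <queue>
#include <set>
#include <string>
#include <vector>

// Hypothetical stand-ins for the JIT IR, just enough to model the rule.
enum class Kind { Tensor, List, Tuple, Other };

struct Value;
struct Node {
  std::string name;
  std::vector<Value*> inputs;
  std::vector<Value*> outputs;
};
struct Value {
  Kind kind;
  Node* producer;
  std::vector<Node*> users;
};

bool isTensor(const Value* v) { return v->kind == Kind::Tensor; }
bool isListOrTuple(const Value* v) { return v->kind == Kind::List || v->kind == Kind::Tuple; }

void propagate(Node* seed, std::set<Node*>& fallback) {
  std::queue<Node*> q;
  fallback.insert(seed);
  q.push(seed);
  while (!q.empty()) {
    Node* cur = q.front();
    q.pop();
    // Rule 1: whoever produces one of our non-Tensor inputs must fall back,
    // because a TRT segment cannot emit List/Tuple (or other) outputs yet.
    for (Value* in : cur->inputs) {
      if (!isTensor(in) && in->producer && fallback.insert(in->producer).second) {
        q.push(in->producer);
      }
    }
    // Rule 2: consumers fall back only for outputs that are neither Tensor
    // nor List/Tuple, since collections are now valid TRT segment inputs.
    for (Value* out : cur->outputs) {
      if (!isTensor(out) && !isListOrTuple(out)) {
        for (Node* user : out->users) {
          if (fallback.insert(user).second) q.push(user);
        }
      }
    }
  }
}

int main() {
  Value list_out{Kind::List, nullptr, {}};
  Node producer{"prim::ListConstruct", {}, {&list_out}};
  Node consumer{"consumer", {&list_out}, {}};
  list_out.producer = &producer;
  list_out.users = {&consumer};

  std::set<Node*> fallback;
  propagate(&consumer, fallback);
  // Rule 1 pulls the List's producer into the fallback set with its consumer.
  std::cout << fallback.count(&producer) << "\n"; // 1
  return 0;
}
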
@@ -176,7 +176,7 @@ void registerSegmentsOutputs(PartitionedGraph& segmented_blocks, torch::jit::Block* block) {
     if (std::find(seg_block.raw_inputs().begin(), seg_block.raw_inputs().end(), mini_graph_input) ==
             seg_block.raw_inputs().end() &&
         seg_block.contain_raw_value(mini_graph_input)) {
-      if (!isTensorOrTensorList(mini_graph_input) && seg_block.target() == SegmentedBlock::kTensorRT)
+      if (!isTensor(mini_graph_input) && seg_block.target() == SegmentedBlock::kTensorRT)
         continue;
       seg_block.registerOutput(mini_graph_input);
     }
@@ -242,36 +242,6 @@ bool check_node_fallback(torch::jit::Node* n, const std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
           "Node fallback to Torch because the NonTensor dependencies with other fallback nodes: "
           << util::node_info(n));
     }
-  }
-  return false;
-}
-
-bool is_collection(torch::jit::Node* n) {
-  for (auto out: n->outputs()) {
-    if(out->type()->kind() == torch::jit::TypeKind::TupleType || out->type()->kind() == torch::jit::TypeKind::ListType) {
-      return true;
-    }
-  }
-  return false;
-}
-
-bool should_run_in_trt(torch::jit::Node* n, const std::unordered_set<std::string>& torch_ops) {
-  // If the op is not supported by the conversion phase it should run in PyTorch
-  if (!conversion::OpSupported(n)) {
-    LOG_GRAPH("Node not supported by conversion: " << util::node_info(n));
-    return false;
-  }
-
-  // If the user specifies the op to run in Torch it should run in PyTorch
-  if (torch_ops.find(n->kind().toQualString()) != torch_ops.end()) {
-    LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n));
-    return false;
-  }
-
-  // If the user specifies the module containing this op to run in torch it should run in PyTorch
-  const auto to_compile_sym = c10::Symbol::attr("to_compile");
-  if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) {
-    LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n));
   return false;
 }
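
The deleted is_collection()/should_run_in_trt() pair made the run-in-TRT decision inline during segmentation; after this change the decision is precomputed into global_fallback_nodes and segment_graph() consults it through check_node_fallback() (see the segmentation hunk below). A condensed, hypothetical sketch of that lookup pattern follows; the real function also distinguishes and logs the recorded FallbackNodeType reason.

#include <torch/csrc/jit/ir/ir.h>

#include <unordered_map>

// Hypothetical condensed form of the lookup (illustrative, not the real body):
// a node stays TRT-eligible only if no earlier analysis pass recorded a
// fallback reason for it in the precomputed map.
bool runs_in_trt(torch::jit::Node* n,
                 const std::unordered_map<torch::jit::Node*, int>& global_fallback_nodes) {
  return global_fallback_nodes.find(n) == global_fallback_nodes.end();
}
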

@@ -390,19 +360,18 @@ PartitionedGraph segment_graph(
   find_min_block_size_fallback_nodes(block, global_fallback_nodes, min_block_size);
 
   auto nodes = block->nodes();
-  auto reverse_nodes = nodes.reverse(); // merge from output side to input side
   PartitionedGraph segmented_blocks;
 
   // segment the nodes
   std::vector<torch::jit::Node*> in_prog_trt_blk_nodes, in_prog_pyt_blk_nodes;
-  for (const auto n : reverse_nodes) {
+  for (const auto n : nodes) {
     // Skip constant nodes as they are resources for both kinds of modules
     if (n->kind() == torch::jit::prim::Constant) {
       continue;
     }
     // the outputs of trt subgraph shouldn't be collections
-    if (should_run_in_trt(n, forced_fallback_ops) && !(in_prog_trt_blk_nodes.size() == 0 && is_collection(n))) {
-      in_prog_trt_blk_nodes.insert(in_prog_trt_blk_nodes.begin(), n);
+    if (check_node_fallback(n, global_fallback_nodes)) {
+      in_prog_trt_blk_nodes.push_back(n);
 
       // If there is an active PyTorch block and we have passed the threshold for a valid TRT
       // block then segment and reset the active PyTorch block
@@ -418,7 +387,7 @@ PartitionedGraph segment_graph(
         LOG_DEBUG(
             "In progress TRT block does not meet minimum block size requirements, therefore folding into in progress PyTorch block");
         in_prog_pyt_blk_nodes.insert(
-            in_prog_pyt_blk_nodes.begin(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
+            in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
       }
       in_prog_trt_blk_nodes.clear();
       // if there is a prim::If then this if node will be encapsulated in a SegmentedBlock
@@ -437,14 +406,14 @@ PartitionedGraph segment_graph(
         finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
       }
       if (checkLoopEvaluatable(n)) {
-        in_prog_trt_blk_nodes.insert(in_prog_trt_blk_nodes.begin(), n);
+        in_prog_trt_blk_nodes.push_back(n);
       } else {
         auto loop_node = std::vector<torch::jit::Node*>{n};
         finalize_block(segmented_blocks, SegmentedBlock::kTorch, loop_node);
       }
       continue;
     }
-    in_prog_pyt_blk_nodes.insert(in_prog_pyt_blk_nodes.begin(), n);
+    in_prog_pyt_blk_nodes.push_back(n);
   }
 }
 

@@ -459,7 +428,6 @@ PartitionedGraph segment_graph(
         in_prog_pyt_blk_nodes.end(), in_prog_trt_blk_nodes.begin(), in_prog_trt_blk_nodes.end());
     finalize_block(segmented_blocks, SegmentedBlock::kTorch, in_prog_pyt_blk_nodes);
   }
-  std::reverse(segmented_blocks.begin(), segmented_blocks.end());
   return segmented_blocks;
 }
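
All of the push_back()/forward-iteration changes above come from one refactor: segmentation used to walk the graph from outputs to inputs, prepend each node, and reverse the finished block list, whereas it now walks inputs to outputs and appends. A small self-contained check that the two schemes yield the same node order (plain ints stand in for JIT nodes, purely for illustration), with appending avoiding the quadratic cost of repeated front insertion:

#include <cassert>
#include <vector>

int main() {
  std::vector<int> nodes = {1, 2, 3, 4, 5}; // stand-ins for graph nodes

  // Old scheme: iterate in reverse and prepend -- O(n) work per insert.
  std::vector<int> old_order;
  for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) {
    old_order.insert(old_order.begin(), *it);
  }
  // (the deleted std::reverse() fixed up the order of whole segmented
  //  blocks, which the reverse walk likewise collected back-to-front)

  // New scheme: iterate forward and append -- amortized O(1) per node.
  std::vector<int> new_order;
  for (int n : nodes) {
    new_order.push_back(n);
  }

  assert(old_order == new_order); // same node order either way
  return 0;
}
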
