Skip to content

Commit 418d1e5

Browse files
committed
refactor: still fallback when a trt segment has tuple/list input/output
Signed-off-by: Bo Wang <[email protected]>
1 parent d479c98 commit 418d1e5

File tree

1 file changed

+3
-10
lines changed

1 file changed

+3
-10
lines changed

core/partitioning/partitioning.cpp

+3-10
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,6 @@ inline bool isTensor(torch::jit::Value* val) {
2121
return val->type()->isSubtypeOf(torch::jit::TensorType::get());
2222
}
2323

24-
inline bool isListOrTuple(torch::jit::Value* val) {
25-
return val->type()->kind() == torch::jit::TypeKind::TupleType || val->type()->kind() == torch::jit::TypeKind::ListType;
26-
}
27-
2824
bool containNonTensorOutputs(torch::jit::Node* n) {
2925
for (auto output : n->outputs()) {
3026
if (!isTensor(output)) {
@@ -109,22 +105,19 @@ void find_all_fallback_nodes(
109105
auto cur_node = q.front();
110106
q.pop();
111107
// for every node that produces this fallback node's NonTensor input, they should fallback too
112-
// Even collection feature is supported, since TRT List/Tuple output is not supported yet, the nodes
113-
// that produce List/Tuple still cannot be in TRT segment
114108
for (auto input : cur_node->inputs()) {
115109
if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
116110
global_fallback_nodes.insert({input->node(), FallbackNodeType::kNON_TENSOR}).second) {
117111
q.push(input->node());
118112
}
119113
}
120114
// for every node that consumes this fallback node's NonTensor output, they should fallback too
121-
// Since collection feature is supported, we can have List/Tuple input for TRT segment, so we only
122-
// fallback the nodes that take inputs which are not Tensor/List/Tuple
123115
for (auto output : cur_node->outputs()) {
124-
if (!isTensor(output) && !isListOrTuple(output)) {
116+
if (!isTensor(output)) {
125117
for (auto use : output->uses()) {
126118
auto node = use.user;
127-
if (node->kind() != torch::jit::prim::Constant && global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
119+
if (node->kind() != torch::jit::prim::Constant &&
120+
global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
128121
q.push(node);
129122
}
130123
}

0 commit comments

Comments
 (0)