@@ -21,10 +21,6 @@ inline bool isTensor(torch::jit::Value* val) {
   return val->type()->isSubtypeOf(torch::jit::TensorType::get());
 }
 
-inline bool isListOrTuple(torch::jit::Value* val) {
-  return val->type()->kind() == torch::jit::TypeKind::TupleType || val->type()->kind() == torch::jit::TypeKind::ListType;
-}
-
 bool containNonTensorOutputs(torch::jit::Node* n) {
   for (auto output : n->outputs()) {
     if (!isTensor(output)) {
@@ -109,22 +105,19 @@ void find_all_fallback_nodes(
     auto cur_node = q.front();
     q.pop();
     // for every node that produces this fallback node's NonTensor input, they should fallback too
-    // Even collection feature is supported, since TRT List/Tuple output is not supported yet, the nodes
-    // that produce List/Tuple still cannot be in TRT segment
     for (auto input : cur_node->inputs()) {
       if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
           global_fallback_nodes.insert({input->node(), FallbackNodeType::kNON_TENSOR}).second) {
         q.push(input->node());
       }
     }
     // for every node that consumes this fallback node's NonTensor output, they should fallback too
-    // Since collection feature is supported, we can have List/Tuple input for TRT segment, so we only
-    // fallback the nodes that take inputs which are not Tensor/List/Tuple
     for (auto output : cur_node->outputs()) {
-      if (!isTensor(output) && !isListOrTuple(output)) {
+      if (!isTensor(output)) {
         for (auto use : output->uses()) {
           auto node = use.user;
-          if (node->kind() != torch::jit::prim::Constant && global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
+          if (node->kind() != torch::jit::prim::Constant &&
+              global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
             q.push(node);
           }
         }
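
The hunk above keeps the BFS that spreads fallback status to the producers and consumers of non-Tensor values, with `insert(...).second` serving as the visited check. Below is a minimal standalone sketch of that pattern under simplified, hypothetical types (`Node`, `producers`, `consumers`, and `propagate_fallback` are illustrative stand-ins, not the `torch::jit` graph API):

```cpp
#include <queue>
#include <unordered_map>
#include <vector>

// Hypothetical simplified graph node; only the queue + insert(...).second
// dedup pattern mirrors the real find_all_fallback_nodes code.
struct Node {
  std::vector<Node*> producers;  // nodes whose non-Tensor outputs feed this node
  std::vector<Node*> consumers;  // nodes that consume this node's non-Tensor outputs
};

enum class FallbackNodeType { kNON_TENSOR };

void propagate_fallback(Node* seed, std::unordered_map<Node*, FallbackNodeType>& fallback_nodes) {
  std::queue<Node*> q;
  q.push(seed);
  while (!q.empty()) {
    Node* cur = q.front();
    q.pop();
    // insert(...).second is true only on the first insertion, so each node
    // is enqueued at most once even if it is reachable along several paths
    for (Node* p : cur->producers) {
      if (fallback_nodes.insert({p, FallbackNodeType::kNON_TENSOR}).second) {
        q.push(p);
      }
    }
    for (Node* c : cur->consumers) {
      if (fallback_nodes.insert({c, FallbackNodeType::kNON_TENSOR}).second) {
        q.push(c);
      }
    }
  }
}
```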