@@ -115,7 +115,7 @@ void find_all_fallback_nodes(
     // for every node that produces this fallback node's NonTensor input, they should fallback too
     for (auto input : cur_node->inputs()) {
       if (!isTensor(input) && input->node()->kind() != torch::jit::prim::Constant &&
-          global_fallback_nodes.insert({input->node(), 4}).second) {
+          global_fallback_nodes.insert({input->node(), FallbackNodeType::kNON_TENSOR}).second) {
         q.push(input->node());
       }
     }
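
The FallbackNodeType definition itself is not shown in these hunks. Judging from the five constants the diff references and the magic numbers 0-4 they replace, it is presumably an unscoped enum along these lines (a sketch; the comments are inferred from the log messages in check_node_fallback below):

// Presumed definition, reconstructed from the constants used in this diff.
enum FallbackNodeType {
  kUNSUPPORTED,        // was 0: no converter support for this op
  kOPERATOR_FALLBACK,  // was 1: op appears in the forced_fallback_ops list
  kMODULE_FALLBACK,    // was 2: enclosing module was marked to run in Torch
  kMIN_BLOCK_FALLBACK, // was 3: segment failed the min_block_size requirement
  kNON_TENSOR,         // was 4: shares a non-tensor value with a fallback node
};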
@@ -124,7 +124,7 @@ void find_all_fallback_nodes(
       if (!isTensor(output)) {
         for (auto use : output->uses()) {
           auto node = use.user;
-          if (node->kind() != torch::jit::prim::Constant && global_fallback_nodes.insert({node, 4}).second) {
+          if (node->kind() != torch::jit::prim::Constant && global_fallback_nodes.insert({node, FallbackNodeType::kNON_TENSOR}).second) {
            q.push(node);
           }
         }
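
Both hunks above lean on the same idiom: std::unordered_map::insert returns a pair<iterator, bool>, and .second is true only when the key was actually inserted. One call therefore both records the fallback reason and acts as the visited check for the breadth-first walk, so each node is enqueued at most once. A minimal, self-contained sketch of the pattern, with int node ids standing in for torch::jit::Node*:

#include <iostream>
#include <queue>
#include <unordered_map>
#include <vector>

int main() {
  // Toy dependency edges: node id -> producers of its inputs.
  std::unordered_map<int, std::vector<int>> producers = {{3, {1, 2}}, {2, {1}}, {1, {}}};

  std::unordered_map<int, int> fallback; // node id -> fallback reason
  std::queue<int> q;

  fallback.insert({3, 0}); // seed: node 3 is unsupported
  q.push(3);

  while (!q.empty()) {
    int cur = q.front();
    q.pop();
    for (int p : producers[cur]) {
      // .second is true only on first insertion, so each node is
      // marked and enqueued exactly once, the same trick as the diff.
      if (fallback.insert({p, 4}).second) {
        q.push(p);
      }
    }
  }

  for (const auto& kv : fallback)
    std::cout << "node " << kv.first << " reason " << kv.second << "\n";
}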
@@ -229,13 +229,13 @@ bool checkLoopEvaluatable(torch::jit::Node* n) {
 
 bool check_node_fallback(torch::jit::Node* n, const std::unordered_map<torch::jit::Node*, int>& fallback_nodes) {
   if (fallback_nodes.count(n)) {
-    if (fallback_nodes.at(n) == 0) {
+    if (fallback_nodes.at(n) == FallbackNodeType::kUNSUPPORTED) {
       LOG_GRAPH("Node not supported by conversion: " << util::node_info(n));
-    } else if (fallback_nodes.at(n) == 1) {
+    } else if (fallback_nodes.at(n) == FallbackNodeType::kOPERATOR_FALLBACK) {
       LOG_GRAPH("Node explicitly set to run in torch: " << util::node_info(n));
-    } else if (fallback_nodes.at(n) == 2) {
+    } else if (fallback_nodes.at(n) == FallbackNodeType::kMODULE_FALLBACK) {
       LOG_GRAPH("Node is within a module set to run in torch: " << util::node_info(n));
-    } else if (fallback_nodes.at(n) == 3) {
+    } else if (fallback_nodes.at(n) == FallbackNodeType::kMIN_BLOCK_FALLBACK) {
       LOG_GRAPH("Node fallback to Torch because of min_block_size" << util::node_info(n));
     } else {
       LOG_GRAPH(
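
Note that the signature on the context line above still types the map as std::unordered_map<torch::jit::Node*, int>. That continues to compile only because FallbackNodeType is presumably an unscoped enum, which converts implicitly to int both in the insert calls and in the == comparisons; an enum class would have forced the map's value type to change as well. A small standalone demonstration of that assumption:

#include <iostream>
#include <unordered_map>

enum FallbackNodeType { kUNSUPPORTED, kOPERATOR_FALLBACK }; // unscoped: converts to int

int main() {
  std::unordered_map<int, int> reasons;                        // value type stays int
  reasons.insert({42, FallbackNodeType::kOPERATOR_FALLBACK});  // enum -> int, implicit
  if (reasons.at(42) == FallbackNodeType::kOPERATOR_FALLBACK)  // int == enum, also implicit
    std::cout << "implicit conversion works\n";
  // With enum class FallbackNodeType, both lines would need a static_cast.
}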
@@ -273,18 +273,18 @@ void get_fallback_nodes(
 
     // If the op is not supported by the conversion phase it should run in PyTorch
     if (!conversion::OpSupported(n)) {
-      fallback_nodes.insert({n, 0});
+      fallback_nodes.insert({n, FallbackNodeType::kUNSUPPORTED});
     }
 
     // If the user specifies the op to run in Torch it should run in PyTorch
     if (forced_fallback_ops.find(n->kind().toQualString()) != forced_fallback_ops.end()) {
-      fallback_nodes.insert({n, 1});
+      fallback_nodes.insert({n, FallbackNodeType::kOPERATOR_FALLBACK});
     }
 
     // If the user specifies the module containing this op to run in torch it should run in PyTorch
     const auto to_compile_sym = c10::Symbol::attr("to_compile");
     if (n->hasAttribute(to_compile_sym) && n->i(to_compile_sym) == (int64_t) false) {
-      fallback_nodes.insert({n, 2});
+      fallback_nodes.insert({n, FallbackNodeType::kMODULE_FALLBACK});
     }
   }
   return;
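
One subtlety in get_fallback_nodes: the three checks are not mutually exclusive (an op can be both unsupported and force-listed), but std::unordered_map::insert never overwrites an existing key, so the first reason recorded for a node wins: kUNSUPPORTED takes precedence over kOPERATOR_FALLBACK, which takes precedence over kMODULE_FALLBACK. A minimal demonstration of that insert semantics:

#include <iostream>
#include <unordered_map>

int main() {
  std::unordered_map<int, int> m;
  m.insert({7, 0});             // first classification, e.g. kUNSUPPORTED
  m.insert({7, 1});             // silently ignored: key 7 already present
  std::cout << m.at(7) << "\n"; // prints 0, the first reason sticks
  // m[7] = 1; or m.insert_or_assign(7, 1); would overwrite instead.
}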
@@ -329,7 +329,7 @@ void find_min_block_size_fallback_nodes(
   // keep fallback until all segments meet the min_block_size requirement
   while (!min_block_fallback_nodes.empty()) {
     for (const auto i : min_block_fallback_nodes) {
-      initial_fallback_nodes.insert({i, 3});
+      initial_fallback_nodes.insert({i, FallbackNodeType::kMIN_BLOCK_FALLBACK});
     }
     global_fallback_nodes.insert(initial_fallback_nodes.begin(), initial_fallback_nodes.end());
     // find the fallback nodes because of dependency with min_block_size caused fallback nodes
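
The while loop above iterates to a fixed point: marking a too-small segment as Torch fallback can shrink neighboring TensorRT segments below min_block_size in the next segmentation pass, which produces new kMIN_BLOCK_FALLBACK nodes, and so on until nothing changes. A toy, self-contained sketch of that converge-until-stable pattern (a linear chain of node ids rather than a real graph; all names and the setup here are illustrative, not from the PR):

#include <iostream>
#include <unordered_set>

int main() {
  const int num_nodes = 10;
  const int min_block_size = 3;
  std::unordered_set<int> fallback = {2, 6}; // seed fallback nodes

  bool changed = true;
  while (changed) { // iterate until no new fallback nodes appear
    changed = false;
    int run_start = 0;
    for (int i = 0; i <= num_nodes; ++i) {
      if (i == num_nodes || fallback.count(i)) {
        // [run_start, i) is a candidate TensorRT segment; too-small
        // segments fall back wholesale, possibly triggering another pass.
        if (i - run_start > 0 && i - run_start < min_block_size) {
          for (int j = run_start; j < i; ++j)
            changed |= fallback.insert(j).second;
        }
        run_start = i + 1;
      }
    }
  }

  for (int i = 0; i < num_nodes; ++i)
    std::cout << "node " << i << (fallback.count(i) ? " -> torch\n" : " -> tensorrt\n");
}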