@@ -39,7 +39,7 @@ void NotateModuleForFallback(
39
39
if (n->kind () == torch::jit::prim::GetAttr) {
40
40
auto out_type = unmangle_cls_name (c10::toString (n->output (0 )->type ()));
41
41
if (forced_fallback_modules.find (out_type) != forced_fallback_modules.end ()) {
42
- LOG_DEBUG (
42
+ LOG_GRAPH (
43
43
" Notating module for fallback: " << n->s (c10::attr::name) << " (" << out_type << " ) [owner: " << mod_name
44
44
<< " (" << cls_name << " )]" );
45
45
auto uses = n->output (0 )->uses ();
@@ -58,11 +58,32 @@ void NotateModuleForFallback(
58
58
}
59
59
60
60
if (changed_mod) {
61
- LOG_DEBUG (" Notated graph: " << *g);
61
+ LOG_GRAPH (" Notated graph: " << *g);
62
62
}
63
63
64
- for (const auto sub_mod : mod.named_children ()) {
65
- NotateModuleForFallback (sub_mod.value , sub_mod.name , method_name, forced_fallback_modules);
64
+ if (mod.named_children ().size () > 0 ) {
65
+ for (const auto n : nodes) {
66
+ std::string sub_method_name = " " ;
67
+ if (n->kind () == torch::jit::prim::CallMethod) {
68
+ sub_method_name = n->s (c10::Symbol::attr (" name" ));
69
+ auto sub_mod_val = n->input (0 );
70
+ auto sub_mod_src_n = sub_mod_val->node ();
71
+ if (!sub_mod_src_n->hasAttributeS (" name" )) {
72
+ LOG_GRAPH (" Node: " << util::node_info (sub_mod_src_n) << " manages a module with no name, skipping" );
73
+ break ;
74
+ }
75
+ auto sub_mod_name = sub_mod_src_n->s (c10::Symbol::attr (" name" ));
76
+ for (const auto sub_mod : mod.named_children ()) {
77
+ // Theres probably a way to directly access the module we care about
78
+ if (sub_mod.name == sub_mod_name) {
79
+ LOG_GRAPH (
80
+ " Looking at <module>.<method>() next: " << sub_mod_name << " ." << sub_method_name
81
+ << " () (lowering.passes.NotateModuleForFallback)" );
82
+ NotateModuleForFallback (sub_mod.value , sub_mod.name , sub_method_name, forced_fallback_modules);
83
+ }
84
+ }
85
+ }
86
+ }
66
87
}
67
88
}
68
89
@@ -74,23 +95,23 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_del
74
95
auto n = *it;
75
96
if (!mark.top () && n->kind () == torch::jit::prim::Enter && n->hasAttributeS (" compilation_edge" )) {
76
97
if (n->s (c10::Symbol::attr (" compilation_edge" )) == " start" ) {
77
- LOG_DEBUG (" Starting to mark new segmented block targeted for torch" );
98
+ LOG_GRAPH (" Starting to mark new segmented block targeted for torch" );
78
99
mark.push (true );
79
100
if (delete_delims) {
80
101
it.destroyCurrent ();
81
102
}
82
103
}
83
104
} else if (mark.top () && n->kind () == torch::jit::prim::Enter && n->hasAttributeS (" compilation_edge" )) {
84
105
if (n->s (c10::Symbol::attr (" compilation_edge" )) == " start" ) {
85
- LOG_DEBUG (" Found the start of another segmented block targeted for torch while actively marking a block" );
106
+ LOG_GRAPH (" Found the start of another segmented block targeted for torch while actively marking a block" );
86
107
mark.push (true );
87
108
if (delete_delims) {
88
109
it.destroyCurrent ();
89
110
}
90
111
}
91
112
} else if (mark.top () && n->kind () == torch::jit::prim::Exit && n->hasAttributeS (" compilation_edge" )) {
92
113
if (n->s (c10::Symbol::attr (" compilation_edge" )) == " end" ) {
93
- LOG_DEBUG (" Found the end of segmented block targeted for torch while actively marking a block" );
114
+ LOG_GRAPH (" Found the end of segmented block targeted for torch while actively marking a block" );
94
115
mark.pop ();
95
116
if (delete_delims) {
96
117
it.destroyCurrent ();
@@ -106,7 +127,7 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_del
106
127
}
107
128
}
108
129
109
- LOG_DEBUG (" After marking operations for torch fallback: " << *g);
130
+ LOG_GRAPH (" After marking operations for torch fallback: " << *g);
110
131
}
111
132
112
133
} // namespace passes
0 commit comments