Skip to content

Commit 2fc612d

Browse files
committed
fix(//core/lowering): Fixes module level fallback recursion
This commit fixes module level fallback by using method calls to determine modules to recurse down too. This should be robust to names other than forward used for methods as well as ignoring functional modules. Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 722aa94 commit 2fc612d

File tree

2 files changed

+37
-15
lines changed

2 files changed

+37
-15
lines changed

Diff for: core/lowering/lowering.cpp

+11-10
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,9 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
3737
torch::jit::EliminateCommonSubexpression(g);
3838
}
3939
torch::jit::EliminateDeadCode(g);
40-
passes::MarkNodesForFallback(g, true);
40+
if (lower_info.forced_fallback_modules.size() > 0) {
41+
passes::MarkNodesForFallback(g, true);
42+
}
4143
passes::UnpackHardSwish(g);
4244
passes::EliminateExceptionOrPassPattern(g);
4345
passes::ReduceToOperation(g);
@@ -60,12 +62,13 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
6062
LOG_GRAPH(*g);
6163
}
6264

63-
torch::jit::Module LowerModule(
64-
const torch::jit::Module& mod,
65-
std::string method_name,
66-
std::unordered_set<std::string> forced_fallback_modules) {
67-
passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules);
68-
LOG_GRAPH("After MLF notation pass: " << *mod.get_method(method_name).graph());
65+
torch::jit::Module LowerModule(const torch::jit::Module& mod, std::string method_name, const LowerInfo& lower_info) {
66+
std::unordered_set<std::string> forced_fallback_modules(
67+
lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end());
68+
if (forced_fallback_modules.size() > 0) {
69+
passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules);
70+
LOG_GRAPH("After MLF notation pass: " << *mod.get_method(method_name).graph());
71+
}
6972
auto mod_ = torch::jit::freeze_module(mod);
7073
LOG_GRAPH("After freeze: " << *mod_.get_method(method_name).graph());
7174
return mod_;
@@ -77,9 +80,7 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
7780
const LowerInfo& lower_info) {
7881
LOG_DEBUG(lower_info);
7982
LOG_GRAPH("Before lowering: " << *mod.get_method(method_name).graph());
80-
std::unordered_set<std::string> forced_fallback_modules(
81-
lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end());
82-
auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod, method_name, forced_fallback_modules);
83+
auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod, method_name, lower_info);
8384
auto g = lowered_mod.get_method(method_name).graph();
8485

8586
LOG_GRAPH("LibTorch Lowering");

Diff for: core/lowering/passes/module_fallback.cpp

+26-5
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,29 @@ void NotateModuleForFallback(
6161
LOG_GRAPH("Notated graph: " << *g);
6262
}
6363

64-
for (const auto sub_mod : mod.named_children()) {
65-
NotateModuleForFallback(sub_mod.value, sub_mod.name, method_name, forced_fallback_modules);
64+
if (mod.named_children().size() > 0) {
65+
for (const auto n : nodes) {
66+
std::string sub_method_name = "";
67+
if (n->kind() == torch::jit::prim::CallMethod) {
68+
sub_method_name = n->s(c10::Symbol::attr("name"));
69+
auto sub_mod_val = n->input(0);
70+
auto sub_mod_src_n = sub_mod_val->node();
71+
if (!sub_mod_src_n->hasAttributeS("name")) {
72+
LOG_GRAPH("Node: " << util::node_info(sub_mod_src_n) << " manages a module with no name, skipping");
73+
break;
74+
}
75+
auto sub_mod_name = sub_mod_src_n->s(c10::Symbol::attr("name"));
76+
for (const auto sub_mod : mod.named_children()) {
77+
// Theres probably a way to directly access the module we care about
78+
if (sub_mod.name == sub_mod_name) {
79+
LOG_GRAPH(
80+
"Looking at <module>.<method>() next: " << sub_mod_name << "." << sub_method_name
81+
<< "() (lowering.passes.NotateModuleForFallback)");
82+
NotateModuleForFallback(sub_mod.value, sub_mod.name, sub_method_name, forced_fallback_modules);
83+
}
84+
}
85+
}
86+
}
6687
}
6788
}
6889

@@ -74,23 +95,23 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_del
7495
auto n = *it;
7596
if (!mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) {
7697
if (n->s(c10::Symbol::attr("compilation_edge")) == "start") {
77-
LOG_DEBUG("Starting to mark new segmented block targeted for torch");
98+
LOG_GRAPH("Starting to mark new segmented block targeted for torch");
7899
mark.push(true);
79100
if (delete_delims) {
80101
it.destroyCurrent();
81102
}
82103
}
83104
} else if (mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) {
84105
if (n->s(c10::Symbol::attr("compilation_edge")) == "start") {
85-
LOG_DEBUG("Found the start of another segmented block targeted for torch while actively marking a block");
106+
LOG_GRAPH("Found the start of another segmented block targeted for torch while actively marking a block");
86107
mark.push(true);
87108
if (delete_delims) {
88109
it.destroyCurrent();
89110
}
90111
}
91112
} else if (mark.top() && n->kind() == torch::jit::prim::Exit && n->hasAttributeS("compilation_edge")) {
92113
if (n->s(c10::Symbol::attr("compilation_edge")) == "end") {
93-
LOG_DEBUG("Found the end of segmented block targeted for torch while actively marking a block");
114+
LOG_GRAPH("Found the end of segmented block targeted for torch while actively marking a block");
94115
mark.pop();
95116
if (delete_delims) {
96117
it.destroyCurrent();

0 commit comments

Comments
 (0)