@@ -37,7 +37,9 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
37
37
torch::jit::EliminateCommonSubexpression (g);
38
38
}
39
39
torch::jit::EliminateDeadCode (g);
40
- passes::MarkNodesForFallback (g, true );
40
+ if (lower_info.forced_fallback_modules .size () > 0 ) {
41
+ passes::MarkNodesForFallback (g, true );
42
+ }
41
43
passes::UnpackHardSwish (g);
42
44
passes::EliminateExceptionOrPassPattern (g);
43
45
passes::ReduceToOperation (g);
@@ -60,12 +62,13 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
60
62
LOG_GRAPH (*g);
61
63
}
62
64
63
- torch::jit::Module LowerModule (
64
- const torch::jit::Module& mod,
65
- std::string method_name,
66
- std::unordered_set<std::string> forced_fallback_modules) {
67
- passes::NotateModuleForFallback (mod, " " , method_name, forced_fallback_modules);
68
- LOG_GRAPH (" After MLF notation pass: " << *mod.get_method (method_name).graph ());
65
+ torch::jit::Module LowerModule (const torch::jit::Module& mod, std::string method_name, const LowerInfo& lower_info) {
66
+ std::unordered_set<std::string> forced_fallback_modules (
67
+ lower_info.forced_fallback_modules .begin (), lower_info.forced_fallback_modules .end ());
68
+ if (forced_fallback_modules.size () > 0 ) {
69
+ passes::NotateModuleForFallback (mod, " " , method_name, forced_fallback_modules);
70
+ LOG_GRAPH (" After MLF notation pass: " << *mod.get_method (method_name).graph ());
71
+ }
69
72
auto mod_ = torch::jit::freeze_module (mod);
70
73
LOG_GRAPH (" After freeze: " << *mod_.get_method (method_name).graph ());
71
74
return mod_;
@@ -77,9 +80,7 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
77
80
const LowerInfo& lower_info) {
78
81
LOG_DEBUG (lower_info);
79
82
LOG_GRAPH (" Before lowering: " << *mod.get_method (method_name).graph ());
80
- std::unordered_set<std::string> forced_fallback_modules (
81
- lower_info.forced_fallback_modules .begin (), lower_info.forced_fallback_modules .end ());
82
- auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule (mod, method_name, forced_fallback_modules);
83
+ auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule (mod, method_name, lower_info);
83
84
auto g = lowered_mod.get_method (method_name).graph ();
84
85
85
86
LOG_GRAPH (" LibTorch Lowering" );
0 commit comments