Skip to content

Commit 0266f41

Browse files
committed
fix: Move some lowering passes to graph level logging
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent ef2732a commit 0266f41

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

Diff for: core/lowering/passes/module_fallback.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ void NotateModuleForFallback(
3939
if (n->kind() == torch::jit::prim::GetAttr) {
4040
auto out_type = unmangle_cls_name(c10::toString(n->output(0)->type()));
4141
if (forced_fallback_modules.find(out_type) != forced_fallback_modules.end()) {
42-
LOG_DEBUG(
42+
LOG_GRAPH(
4343
"Notating module for fallback: " << n->s(c10::attr::name) << " (" << out_type << ") [owner: " << mod_name
4444
<< " (" << cls_name << ")]");
4545
auto uses = n->output(0)->uses();
@@ -58,7 +58,7 @@ void NotateModuleForFallback(
5858
}
5959

6060
if (changed_mod) {
61-
LOG_DEBUG("Notated graph: " << *g);
61+
LOG_GRAPH("Notated graph: " << *g);
6262
}
6363

6464
for (const auto sub_mod : mod.named_children()) {
@@ -106,10 +106,10 @@ void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_del
106106
}
107107
}
108108

109-
LOG_DEBUG("After marking operations for torch fallback: " << *g);
109+
LOG_GRAPH("After marking operations for torch fallback: " << *g);
110110
}
111111

112112
} // namespace passes
113113
} // namespace lowering
114114
} // namespace core
115-
} // namespace trtorch
115+
} // namespace trtorch

Diff for: core/lowering/passes/remove_nops.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ struct NOPRemoval {
2222
void run() {
2323
removeNode(graph_->block(), "aten::detach");
2424
torch::jit::EliminateDeadCode(graph_);
25-
LOG_DEBUG("RemoveNOPs - Note: Removing operators that have no meaning in TRT");
25+
LOG_GRAPH("RemoveNOPs - Note: Removing operators that have no meaning in TRT");
2626
LOG_GRAPH("Post aten::detach removal: " << *graph_);
2727
}
2828

Diff for: core/lowering/passes/unpack_var.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph) {
4242
torch::jit::SubgraphRewriter var_rewriter;
4343
var_rewriter.RegisterRewritePattern(var_pattern, unpacked_pattern);
4444
var_rewriter.runOnGraph(graph);
45-
LOG_DEBUG("Post unpack var: " << *graph);
45+
LOG_GRAPH("Post unpack var: " << *graph);
4646
}
4747

4848
} // namespace passes

0 commit comments

Comments
 (0)