Skip to content

Commit 2e04ce5

Browse files
committed
feat(//core/lowering): Adding two passes, one to delimit and one to mark
ops to fallback Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent ad07645 commit 2e04ce5

File tree

4 files changed

+108
-0
lines changed

4 files changed

+108
-0
lines changed

Diff for: core/lowering/passes/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ cc_library(
1515
"exception_elimination.cpp",
1616
"fuse_addmm_branches.cpp",
1717
"linear_to_addmm.cpp",
18+
"module_fallback.cpp",
1819
"op_aliasing.cpp",
1920
"reduce_to.cpp",
2021
"remove_bn_dim_check.cpp",

Diff for: core/lowering/passes/module_fallback.cpp

+103
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
#include <stack>
2+
#include <unordered_set>
3+
4+
#include "core/lowering/passes/passes.h"
5+
#include "core/util/prelude.h"
6+
7+
namespace trtorch {
8+
namespace core {
9+
namespace lowering {
10+
namespace passes {
11+
12+
// Strips TorchScript name mangling from a class's qualified name so it can be
// compared against user-supplied module names.
//
// A scripted submodule's qualified type name typically looks like:
//   __torch__.___torch_mangle_<N>.my.namespace.ModuleName
// This removes the "__torch__." prefix and the "___torch_mangle_<N>." segment,
// returning "my.namespace.ModuleName". Names without either marker are
// returned unchanged.
std::string unmangle_cls_name(const std::string& name) {
  auto unmangled = name;

  // Drop the "__torch__" prefix plus its trailing '.' (9 + 1 chars).
  std::size_t torch_prefix = unmangled.find("__torch__");
  if (torch_prefix != std::string::npos) {
    unmangled.erase(torch_prefix, 10);
  }

  // Drop the entire "___torch_mangle_<N>." segment. The previous code erased a
  // fixed 21 characters, which assumed the mangle index <N> is always 4 digits
  // ("___torch_mangle_" itself is 16 chars); shorter or longer indices made the
  // erase bite into the remaining class name. Erase up to and including the
  // '.' that terminates the segment instead.
  std::size_t mangle_pos = unmangled.find("___torch_mangle_");
  if (mangle_pos != std::string::npos) {
    std::size_t seg_end = unmangled.find('.', mangle_pos);
    if (seg_end != std::string::npos) {
      unmangled.erase(mangle_pos, seg_end - mangle_pos + 1);
    } else {
      // No terminating dot: the mangle marker runs to the end of the string.
      unmangled.erase(mangle_pos);
    }
  }
  return unmangled;
}
26+
27+
// Scans the graph of `method_name` on `mod` for prim::GetAttr nodes whose
// (unmangled) output type appears in `forced_fallback_modules`. Every use of
// such a submodule is bracketed with outputless prim::Enter / prim::Exit
// delimiter nodes carrying a "compilation_edge" string attribute ("start" /
// "end"). These delimiters are later consumed by MarkNodesForFallback to tag
// the enclosed ops to run in PyTorch. Recurses into named children so nested
// submodules are notated as well.
//
// mod:                     module whose method graph is scanned (and recursed).
// mod_name:                attribute name of `mod` in its parent, for logging.
// method_name:             method whose graph to notate (e.g. "forward").
// forced_fallback_modules: unmangled class names the user forced to fallback.
void NotateModuleForFallback(const torch::jit::Module& mod, std::string mod_name, const std::string& method_name, std::unordered_set<std::string> forced_fallback_modules) {
  auto cls_name = unmangle_cls_name(mod.type()->name()->qualifiedName());
  auto g = mod.get_method(method_name).graph();

  auto nodes = g->block()->nodes();
  bool changed_mod = false;
  for (const auto n : nodes) {
    if (n->kind() == torch::jit::prim::GetAttr) {
      auto out_type = unmangle_cls_name(c10::toString(n->output(0)->type()));
      if (forced_fallback_modules.find(out_type) != forced_fallback_modules.end()) {
        LOG_DEBUG("Marking module for fallback: " << n->s(c10::attr::name) << " (" << out_type << ") [owner: " << mod_name << " (" << cls_name << ")]");
        auto uses = n->output(0)->uses();
        for (const auto u : uses) {
          auto user = u.user;
          // Delimiter nodes take no outputs; they exist solely to carry the
          // "compilation_edge" attribute marking the segment boundary around
          // each user of the fallback submodule.
          auto delim_start_n = g->create(torch::jit::prim::Enter, 0);
          delim_start_n->s_(c10::Symbol::attr("compilation_edge"), "start");
          auto delim_end_n = g->create(torch::jit::prim::Exit, 0);
          delim_end_n->s_(c10::Symbol::attr("compilation_edge"), "end");
          delim_start_n->insertBefore(user);
          delim_end_n->insertAfter(user);
        }
        changed_mod = true;
      }
    }
  }

  if (changed_mod) {
    LOG_DEBUG(*g);
  }

  for (const auto sub_mod : mod.named_children()) {
    NotateModuleForFallback(sub_mod.value, sub_mod.name, method_name, forced_fallback_modules);
  }
}
62+
63+
// Consumes the prim::Enter / prim::Exit "compilation_edge" delimiters inserted
// by NotateModuleForFallback. While inside a delimited region, every node is
// tagged with an i-attribute "to_compile" = false so later phases leave it to
// run in PyTorch. Delimiter nodes themselves are destroyed as they are
// consumed. A stack tracks nesting: "start" pushes, "end" pops, so nested
// delimited regions stay marked until the outermost "end" is seen.
//
// NOTE(review): only the nodes of the graph's top-level block are traversed;
// nodes inside sub-blocks (e.g. prim::If / prim::Loop bodies) are not visited
// here — confirm that is intended.
void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g) {
  auto b = g->block();

  // Seeded with `false` so mark.top() is always valid; `true` on top means we
  // are currently inside a region targeted for PyTorch execution.
  std::stack<bool> mark = std::stack<bool>({false});
  for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
    auto n = *it;
    if (!mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) {
      // First "start" delimiter while not marking: begin a marked region and
      // remove the delimiter node (destroyCurrent keeps the iterator valid).
      if (n->s(c10::Symbol::attr("compilation_edge")) == "start") {
        LOG_DEBUG("Starting to mark new segmented targeted for torch");
        mark.push(true);
        it.destroyCurrent();
      }
    } else if (mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) {
      // Nested "start" while already marking: push again so the matching
      // nested "end" does not terminate the outer region.
      if(n->s(c10::Symbol::attr("compilation_edge")) == "start") {
        LOG_DEBUG("Found the start of another segmented block targeted for torch while actively marking a block");
        mark.push(true);
        it.destroyCurrent();
      }
    } else if (mark.top() && n->kind() == torch::jit::prim::Exit && n->hasAttributeS("compilation_edge")) {
      // "end" while marking: close the innermost open region.
      if(n->s(c10::Symbol::attr("compilation_edge")) == "end") {
        LOG_DEBUG("Found the end of segmented block targeted for torch while actively marking a block");
        mark.pop();
        it.destroyCurrent();
      }
    } else if (!mark.top() && n->kind() == torch::jit::prim::Exit && n->hasAttributeS("compilation_edge")) {
      // Unbalanced "end" with no open region: warn but leave the node alone.
      if(n->s(c10::Symbol::attr("compilation_edge")) == "end") {
        LOG_WARNING("Found the end of segmented block targeted for torch while not actively marking a block");
      }
    } else if (mark.top()) {
      // Ordinary node inside a marked region: flag it to skip TRT conversion.
      LOG_GRAPH("Marking " << util::node_info(n) << " to run in PyTorch");
      n->i_(c10::Symbol::attr("to_compile"), (int64_t) false);
    }
  }

  LOG_GRAPH("Post marking ops for pytorch execution: " << *g);
}
99+
100+
} // namespace passes
101+
} // namespace lowering
102+
} // namespace core
103+
} // namespace trtorch

Diff for: core/lowering/passes/passes.h

+3
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@ namespace core {
77
namespace lowering {
88
namespace passes {
99

10+
void NotateModuleForFallback(const torch::jit::Module& mod, std::string mod_name, const std::string& method_name, std::unordered_set<std::string> forced_fallback_modules);
11+
1012
void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
1113
void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
1214
void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
1315
void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph);
1416
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
1517
void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph);
18+
void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g);
1619
void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
1720
void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
1821
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);

Diff for: core/lowering/register_trt_placeholder_ops.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "torch/csrc/jit/runtime/custom_operator.h"
2+
#include "torch/library.h"
23

34
namespace torch {
45
namespace jit {

0 commit comments

Comments
 (0)