Skip to content

Commit 7aa57c3

Browse files
committed
feat(aten::dropout_): Remove inplace dropout
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 19c91f2 commit 7aa57c3

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

Diff for: core/conversion/conversion_blacklist.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ const std::unordered_set<std::string>& get_non_convertable_nodes() {
2222
"prim::GetAttr",
2323
"prim::CallMethod",
2424
"prim::Drop",
25-
"aten:dropout",
25+
"aten::dropout",
26+
"aten::dropout_"
2627
};
2728
return nonconvertable_nodes;
2829
}

Diff for: core/lowering/passes/remove_dropout.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,20 @@ void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
2020
remove_dropout.RegisterRewritePattern(
2121
dropout_pattern, no_dropout_pattern);
2222
remove_dropout.runOnGraph(graph);
23+
24+
std::string dropout_inplace_pattern = R"IR(
25+
graph(%input, %4, %5):
26+
%6 = aten::dropout_(%input, %4, %5)
27+
return (%6))IR";
28+
std::string no_dropout_inplace_pattern = R"IR(
29+
graph(%input, %4, %5):
30+
return (%input))IR";
31+
32+
torch::jit::SubgraphRewriter remove_dropout_inplace_pattern;
33+
remove_dropout_inplace_pattern.RegisterRewritePattern(
34+
dropout_inplace_pattern, no_dropout_inplace_pattern);
35+
remove_dropout_inplace_pattern.runOnGraph(graph);
36+
2337
LOG_GRAPH("Post remove dropout: " << *graph);
2438
}
2539

0 commit comments

Comments
 (0)