File tree 2 files changed +16
-1
lines changed
2 files changed +16
-1
lines changed Original file line number Diff line number Diff line change @@ -22,7 +22,8 @@ const std::unordered_set<std::string>& get_non_convertable_nodes() {
22
22
" prim::GetAttr" ,
23
23
" prim::CallMethod" ,
24
24
" prim::Drop" ,
25
- " aten:dropout" ,
25
+ " aten::dropout" ,
26
+ " aten::dropout_"
26
27
};
27
28
return nonconvertable_nodes;
28
29
}
Original file line number Diff line number Diff line change @@ -20,6 +20,20 @@ void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
20
20
remove_dropout.RegisterRewritePattern (
21
21
dropout_pattern, no_dropout_pattern);
22
22
remove_dropout.runOnGraph (graph);
23
+
24
+ std::string dropout_inplace_pattern = R"IR(
25
+ graph(%input, %4, %5):
26
+ %6 = aten::dropout_(%input, %4, %5)
27
+ return (%6))IR" ;
28
+ std::string no_dropout_inplace_pattern = R"IR(
29
+ graph(%input, %4, %5):
30
+ return (%input))IR" ;
31
+
32
+ torch::jit::SubgraphRewriter remove_dropout_inplace_pattern;
33
+ remove_dropout_inplace_pattern.RegisterRewritePattern (
34
+ dropout_inplace_pattern, no_dropout_inplace_pattern);
35
+ remove_dropout_inplace_pattern.runOnGraph (graph);
36
+
23
37
LOG_GRAPH (" Post remove dropout: " << *graph);
24
38
}
25
39
You can’t perform that action at this time.
0 commit comments