File tree 5 files changed +62
-3
lines changed
5 files changed +62
-3
lines changed Original file line number Diff line number Diff line change
1
+ #include " torch/csrc/jit/passes/common_subexpression_elimination.h"
1
2
#include " torch/csrc/jit/passes/dead_code_elimination.h"
2
3
#include " torch/csrc/jit/passes/fuse_linear.h"
3
4
#include " torch/csrc/jit/passes/freeze_module.h"
5
+ #include " torch/csrc/jit/passes/loop_unrolling.h"
4
6
#include " torch/csrc/jit/passes/lower_graph.h"
5
7
#include " torch/csrc/jit/passes/lower_tuples.h"
6
8
#include " torch/csrc/jit/passes/quantization.h"
@@ -30,11 +32,13 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
30
32
passes::FuseFlattenLinear (g);
31
33
passes::Conv2DToConvolution (g);
32
34
passes::FuseAddMMBranches (g);
35
+ torch::jit::EliminateCommonSubexpression (g);
36
+ torch::jit::UnrollLoops (g);
37
+ torch::jit::EliminateCommonSubexpression (g);
33
38
passes::UnpackAddMM (g);
34
39
// passes::UnpackBatchNorm(g);
35
40
passes::UnpackLogSoftmax (g);
36
- // passes::RemoveDimExeception(g);
37
- // irfusers::UnpackBatchNorm(g);
41
+ passes::RemoveTo (g);
38
42
torch::jit::EliminateDeadCode (g);
39
43
LOG_GRAPH (*g);
40
44
}
Original file line number Diff line number Diff line change @@ -19,6 +19,7 @@ cc_library(
19
19
"fuse_flatten_linear.cpp" ,
20
20
"remove_contiguous.cpp" ,
21
21
"remove_dropout.cpp" ,
22
+ "remove_to.cpp" ,
22
23
"unpack_addmm.cpp" ,
23
24
"unpack_batch_norm.cpp" ,
24
25
"unpack_log_softmax.cpp" ,
Original file line number Diff line number Diff line change @@ -13,6 +13,7 @@ void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
13
13
void EliminateExceptionOrPassPattern (std::shared_ptr<torch::jit::Graph> graph);
14
14
void RemoveContiguous (std::shared_ptr<torch::jit::Graph>& graph);
15
15
void RemoveDropout (std::shared_ptr<torch::jit::Graph>& graph);
16
+ void RemoveTo (std::shared_ptr<torch::jit::Graph> graph);
16
17
void UnpackAddMM (std::shared_ptr<torch::jit::Graph>& graph);
17
18
void UnpackBatchNorm (std::shared_ptr<torch::jit::Graph>& graph);
18
19
void UnpackLogSoftmax (std::shared_ptr<torch::jit::Graph>& graph);
Original file line number Diff line number Diff line change @@ -16,7 +16,6 @@ void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
16
16
graph(%input, %4, %5):
17
17
return (%input))IR" ;
18
18
19
- // replace matmul + add pattern to linear
20
19
torch::jit::SubgraphRewriter remove_dropout;
21
20
remove_dropout.RegisterRewritePattern (
22
21
dropout_pattern, no_dropout_pattern);
Original file line number Diff line number Diff line change
1
+ #include " torch/csrc/jit/passes/guard_elimination.h"
2
+ #include " torch/csrc/jit/ir/alias_analysis.h"
3
+ #include " torch/csrc/jit/jit_log.h"
4
+ #include " torch/csrc/jit/passes/constant_propagation.h"
5
+ #include " torch/csrc/jit/passes/peephole.h"
6
+ #include " torch/csrc/jit/runtime/graph_executor.h"
7
+ #include " torch/csrc/jit/passes/dead_code_elimination.h"
8
+
9
+ #include " core/util/prelude.h"
10
+
11
+ #include < vector>
12
+
13
+ namespace trtorch {
14
+ namespace core {
15
+ namespace lowering {
16
+ namespace passes {
17
+ namespace {
18
+ using namespace torch ::jit;
19
+ struct ToRemoval {
20
+ ToRemoval (std::shared_ptr<Graph> graph)
21
+ : graph_(std::move(graph)) {}
22
+
23
+ void run () {
24
+ findTo (graph_->block ());
25
+ torch::jit::EliminateDeadCode (graph_);
26
+ LOG_DEBUG (" RemoveTo - Note: Removing remaining aten::to operators, if type casts need to be preserved, add a pass before this pass is run" );
27
+ LOG_GRAPH (" Post aten::to removal: " << *graph_);
28
+ }
29
+
30
+ private:
31
+ void findTo (Block* b) {
32
+ for (auto it = b->nodes ().begin (); it != b->nodes ().end (); it++) {
33
+ auto n = *it;
34
+ if (n->kind () == c10::Symbol::fromQualString (" aten::to" )) {
35
+ LOG_GRAPH (" Found that node " << *n << " is an to node (RemoveTo)" << std::endl);
36
+ n->outputs ()[0 ]->replaceAllUsesWith (n->inputs ()[1 ]);
37
+ it.destroyCurrent ();
38
+ }
39
+ }
40
+ }
41
+
42
+ std::shared_ptr<Graph> graph_;
43
+ };
44
+ } // namespace
45
+
46
+ void RemoveTo (std::shared_ptr<Graph> graph) {
47
+ ToRemoval tr (std::move (graph));
48
+ tr.run ();
49
+ }
50
+
51
+ } // namespace passes
52
+ } // namespace lowering
53
+ } // namespace core
54
+ } // namespace trtorch
You can’t perform that action at this time.
0 commit comments