Skip to content

Commit 0f63ffa

Browse files
committed
feat(aten::to): Remove remaining typecast operators (should be a very
late pass to run) Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 670817c commit 0f63ffa

File tree

5 files changed

+62
-3
lines changed

5 files changed

+62
-3
lines changed

Diff for: core/lowering/lowering.cpp

+6-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
12
#include "torch/csrc/jit/passes/dead_code_elimination.h"
23
#include "torch/csrc/jit/passes/fuse_linear.h"
34
#include "torch/csrc/jit/passes/freeze_module.h"
5+
#include "torch/csrc/jit/passes/loop_unrolling.h"
46
#include "torch/csrc/jit/passes/lower_graph.h"
57
#include "torch/csrc/jit/passes/lower_tuples.h"
68
#include "torch/csrc/jit/passes/quantization.h"
@@ -30,11 +32,13 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
3032
passes::FuseFlattenLinear(g);
3133
passes::Conv2DToConvolution(g);
3234
passes::FuseAddMMBranches(g);
35+
torch::jit::EliminateCommonSubexpression(g);
36+
torch::jit::UnrollLoops(g);
37+
torch::jit::EliminateCommonSubexpression(g);
3338
passes::UnpackAddMM(g);
3439
//passes::UnpackBatchNorm(g);
3540
passes::UnpackLogSoftmax(g);
36-
//passes::RemoveDimExeception(g);
37-
//irfusers::UnpackBatchNorm(g);
41+
passes::RemoveTo(g);
3842
torch::jit::EliminateDeadCode(g);
3943
LOG_GRAPH(*g);
4044
}

Diff for: core/lowering/passes/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ cc_library(
1919
"fuse_flatten_linear.cpp",
2020
"remove_contiguous.cpp",
2121
"remove_dropout.cpp",
22+
"remove_to.cpp",
2223
"unpack_addmm.cpp",
2324
"unpack_batch_norm.cpp",
2425
"unpack_log_softmax.cpp",

Diff for: core/lowering/passes/passes.h

+1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
1313
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
1414
void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
1515
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
16+
void RemoveTo(std::shared_ptr<torch::jit::Graph> graph);
1617
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
1718
void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
1819
void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);

Diff for: core/lowering/passes/remove_dropout.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
1616
graph(%input, %4, %5):
1717
return (%input))IR";
1818

19-
// replace matmul + add pattern to linear
2019
torch::jit::SubgraphRewriter remove_dropout;
2120
remove_dropout.RegisterRewritePattern(
2221
dropout_pattern, no_dropout_pattern);

Diff for: core/lowering/passes/remove_to.cpp

+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
#include "torch/csrc/jit/passes/guard_elimination.h"
2+
#include "torch/csrc/jit/ir/alias_analysis.h"
3+
#include "torch/csrc/jit/jit_log.h"
4+
#include "torch/csrc/jit/passes/constant_propagation.h"
5+
#include "torch/csrc/jit/passes/peephole.h"
6+
#include "torch/csrc/jit/runtime/graph_executor.h"
7+
#include "torch/csrc/jit/passes/dead_code_elimination.h"
8+
9+
#include "core/util/prelude.h"
10+
11+
#include <vector>
12+
13+
namespace trtorch {
14+
namespace core {
15+
namespace lowering {
16+
namespace passes {
17+
namespace {
18+
using namespace torch::jit;
19+
struct ToRemoval {
20+
ToRemoval(std::shared_ptr<Graph> graph)
21+
: graph_(std::move(graph)) {}
22+
23+
void run() {
24+
findTo(graph_->block());
25+
torch::jit::EliminateDeadCode(graph_);
26+
LOG_DEBUG("RemoveTo - Note: Removing remaining aten::to operators, if type casts need to be preserved, add a pass before this pass is run");
27+
LOG_GRAPH("Post aten::to removal: " << *graph_);
28+
}
29+
30+
private:
31+
void findTo(Block* b) {
32+
for (auto it = b->nodes().begin(); it != b->nodes().end(); it++) {
33+
auto n = *it;
34+
if (n->kind() == c10::Symbol::fromQualString("aten::to")) {
35+
LOG_GRAPH("Found that node " << *n << " is an to node (RemoveTo)" << std::endl);
36+
n->outputs()[0]->replaceAllUsesWith(n->inputs()[1]);
37+
it.destroyCurrent();
38+
}
39+
}
40+
}
41+
42+
std::shared_ptr<Graph> graph_;
43+
};
44+
} // namespace
45+
46+
void RemoveTo(std::shared_ptr<Graph> graph) {
47+
ToRemoval tr(std::move(graph));
48+
tr.run();
49+
}
50+
51+
} // namespace passes
52+
} // namespace lowering
53+
} // namespace core
54+
} // namespace trtorch

0 commit comments

Comments
 (0)