Skip to content

Commit 2c5c0d5

Browse files
committed
feat(conv2d_to_convolution): A pass to map aten::conv2d to _convolution
Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 272ef40 commit 2c5c0d5

File tree

4 files changed

+37
-0
lines changed

4 files changed

+37
-0
lines changed

Diff for: core/lowering/lowering.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
2525
torch::jit::FuseLinear(g);
2626
passes::RemoveDropout(g);
2727
passes::FuseFlattenLinear(g);
28+
passes::Conv2DToConvolution(g);
2829
passes::UnpackAddMM(g);
2930
passes::UnpackLogSoftmax(g);
3031
//passes::RemoveDimExeception(g);

Diff for: core/lowering/passes/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ cc_library(
66
"passes.h",
77
],
88
srcs = [
9+
"conv2d_to_convolution.cpp",
910
"exception_elimination.cpp",
1011
"fuse_flatten_linear.cpp",
1112
"remove_dropout.cpp",

Diff for: core/lowering/passes/conv2d_to_convolution.cpp

+34
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
2+
3+
#include "core/util/prelude.h"
4+
5+
namespace trtorch {
6+
namespace core {
7+
namespace lowering {
8+
namespace passes {
9+
10+
void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
11+
std::string conv2d_pattern = R"IR(
12+
graph(%x, %w, %b, %s, %p, %d, %g):
13+
%4 : Tensor = aten::conv2d(%x, %w, %b, %s, %p, %d, %g)
14+
return (%4))IR";
15+
std::string convolution_pattern = R"IR(
16+
graph(%x, %w, %b, %s, %p, %d, %g):
17+
%1 : bool = prim::Constant[value=1]()
18+
%2 : int[] = prim::Constant[value=[0, 0]]()
19+
%3 : bool = prim::Constant[value=0]()
20+
%4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %3)
21+
return (%4))IR";;
22+
23+
// replace matmul + add pattern to linear
24+
torch::jit::SubgraphRewriter map_conv2d_to_convolution;
25+
map_conv2d_to_convolution.RegisterRewritePattern(
26+
conv2d_pattern, convolution_pattern);
27+
map_conv2d_to_convolution.runOnGraph(graph);
28+
LOG_GRAPH("Post map conv2d -> _convolution: " << *graph);
29+
}
30+
31+
} // namespace passes
32+
} // namespace lowering
33+
} // namespace core
34+
} // namespace trtorch

Diff for: core/lowering/passes/passes.h

+1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ namespace core {
77
namespace lowering {
88
namespace passes {
99

10+
void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
1011
void FuseFlattenLinear(std::shared_ptr<torch::jit::Graph>& graph);
1112
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
1213
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);

0 commit comments

Comments
 (0)