Skip to content

Commit d39b918

Browse files
committed
feat: replace view with reshape during lowering
Signed-off-by: Ruoqian Guo <[email protected]>
1 parent 09afccb commit d39b918

File tree

6 files changed

+74
-0
lines changed

6 files changed

+74
-0
lines changed

Diff for: core/lowering/lowering.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
4545
passes::ReduceToOperation(g);
4646
passes::ReduceGelu(g);
4747
passes::RemoveContiguous(g);
48+
passes::ViewToReshape(g);
4849
passes::RemoveDropout(g);
4950
passes::LinearToAddMM(g);
5051
passes::Conv1DToConvolution(g);

Diff for: core/lowering/passes/BUILD

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ cc_library(
2020
"reduce_gelu.cpp",
2121
"remove_bn_dim_check.cpp",
2222
"remove_contiguous.cpp",
23+
"view_to_reshape.cpp",
2324
"remove_dropout.cpp",
2425
"remove_nops.cpp",
2526
"silu_to_sigmoid_multiplication.cpp",

Diff for: core/lowering/passes/passes.h

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph);
2424
void MarkNodesForFallback(std::shared_ptr<torch::jit::Graph>& g, bool delete_delims);
2525
void RemoveBNDimCheck(std::shared_ptr<torch::jit::Graph> graph);
2626
void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
27+
void ViewToReshape(std::shared_ptr<torch::jit::Graph>& graph);
2728
void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
2829
void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
2930
void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);

Diff for: core/lowering/passes/view_to_reshape.cpp

+31
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
2+
#include "core/util/prelude.h"
3+
4+
namespace torch_tensorrt {
5+
namespace core {
6+
namespace lowering {
7+
namespace passes {
8+
9+
void ViewToReshape(std::shared_ptr<torch::jit::Graph>& graph) {
10+
std::string view_pattern = R"IR(
11+
graph(%x, %1):
12+
%out : Tensor = aten::view(%x, %1)
13+
return (%out))IR";
14+
15+
std::string reshape_pattern = R"IR(
16+
graph(%x, %1):
17+
%out : Tensor = aten::reshape(%x, %1)
18+
return (%out))IR";
19+
20+
// replace aten::view with aten::reshape
21+
torch::jit::SubgraphRewriter map_view_to_reshape;
22+
map_view_to_reshape.RegisterRewritePattern(view_pattern, reshape_pattern);
23+
map_view_to_reshape.runOnGraph(graph);
24+
25+
LOG_GRAPH("Post lowering of aten::view -> " << *graph);
26+
}
27+
28+
} // namespace passes
29+
} // namespace lowering
30+
} // namespace core
31+
} // namespace torch_tensorrt

Diff for: tests/core/lowering/BUILD

+5
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ lowering_test(
5050
name = "test_remove_detach_pass",
5151
)
5252

53+
lowering_test(
54+
name = "test_view_to_reshape_pass",
55+
)
56+
5357
lowering_test(
5458
name = "test_operator_aliasing_pass",
5559
)
@@ -75,6 +79,7 @@ test_suite(
7579
":test_operator_aliasing_pass",
7680
":test_remove_contiguous_pass",
7781
":test_remove_detach_pass",
82+
":test_view_to_reshape_pass",
7883
":test_remove_dropout_pass",
7984
":test_reduce_to_pass",
8085
":test_reduce_gelu",

Diff for: tests/core/lowering/test_view_to_reshape_pass.cpp

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/csrc/jit/ir/subgraph_matcher.h"
8+
9+
TEST(LoweringPasses, ViewToReshapeCorrectly) {
10+
std::string source_graph = R"IR(
11+
graph(%x : Tensor, %1, %1.1):
12+
%0 : int = prim::Constant[value=0]()
13+
%2 : Tensor = aten::permute(%x, %1)
14+
%3 : Tensor = aten::contiguous(%2, %0)
15+
%4 : Tensor = aten::view(%3, %1.1)
16+
return (%4))IR";
17+
std::string target_graph = R"IR(
18+
graph(%x : Tensor, %1, %1.1):
19+
%0 : int = prim::Constant[value=0]()
20+
%2 : Tensor = aten::permute(%x, %1)
21+
%4 : Tensor = aten::reshape(%2, %1.1)
22+
return (%4))IR";
23+
24+
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
25+
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
26+
auto sg = std::make_shared<torch::jit::Graph>();
27+
torch::jit::parseIR(source_graph, &*sg);
28+
torch_tensorrt::core::lowering::passes::RemoveContiguous(sg);
29+
torch_tensorrt::core::lowering::passes::ViewToReshape(sg);
30+
31+
auto tg = std::make_shared<torch::jit::Graph>();
32+
torch::jit::parseIR(target_graph, &*tg);
33+
34+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
35+
}

0 commit comments

Comments
 (0)