1
+ #include < string>
2
+ #include " core/compiler.h"
3
+ #include " core/lowering/passes/passes.h"
4
+ #include " gtest/gtest.h"
5
+ #include " tests/util/util.h"
6
+ #include " torch/csrc/jit/ir/irparser.h"
7
+ #include " torch/csrc/jit/ir/subgraph_matcher.h"
8
+
9
+ TEST (LoweringPasses, RewriteInputsWithParamsCorrectly) {
10
+ std::string source_graph = R"IR(
11
+ graph(%x: Tensor, %y: Tensor, %1 : Int(1)):
12
+ %out: Tensor = aten::sub(%x, %y, %1)
13
+ return (%out))IR" ;
14
+ std::string target_graph = R"IR(
15
+ graph(%x: Tensor, %y : Tensor):
16
+ %2 : int = prim::Constant[value=0]()
17
+ %out: Tensor = aten::sub(%x, %y, %2)
18
+ return (%out))IR" ;
19
+
20
+ torch_tensorrt::core::util::logging::get_logger ().set_reportable_log_level (
21
+ torch_tensorrt::core::util::logging::LogLevel::kGRAPH );
22
+ auto sg = std::make_shared<torch::jit::Graph>();
23
+ torch::jit::parseIR (source_graph, &*sg);
24
+
25
+ torch::jit::IValue param0 = torch::jit::IValue (0 );
26
+ std::vector<torch::jit::IValue> params{param0};
27
+ torch_tensorrt::core::lowering::passes::RewriteInputsWithParams (sg, params);
28
+
29
+ auto tg = std::make_shared<torch::jit::Graph>();
30
+ torch::jit::parseIR (target_graph, &*tg);
31
+
32
+ ASSERT_TRUE (!torch::jit::findPatternMatches (*tg, *sg).empty ());
33
+ }
0 commit comments