Skip to content

Commit 318281d

Browse files
committed
test: add test case for rewrite input with params pass
Signed-off-by: Bo Wang <[email protected]>
1 parent 9fecda0 commit 318281d

File tree

3 files changed

+40
-3
lines changed

3 files changed

+40
-3
lines changed

core/lowering/passes/rewrite_inputs_with_params.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
#include "torch/csrc/jit/ir/constants.h"
21
#include "core/util/prelude.h"
3-
2+
#include "torch/csrc/jit/ir/constants.h"
43

54
namespace torch_tensorrt {
65
namespace core {
76
namespace lowering {
87
namespace passes {
98

10-
119
void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params) {
1210
auto input_size = g->inputs().size();
1311
auto param_it = params.rbegin();
@@ -33,6 +31,7 @@ void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph>& g, std::vector<
3331
}
3432
}
3533
}
34+
LOG_GRAPH("After RewriteInputsWithParams: " << *g);
3635
}
3736

3837
} // namespace passes

tests/core/lowering/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,10 @@ lowering_test(
8787
name = "test_unpack_reduce_ops",
8888
)
8989

90+
lowering_test(
91+
name = "test_rewrite_inputs_with_params",
92+
)
93+
9094
test_suite(
9195
name = "lowering_tests",
9296
tests = [
@@ -102,6 +106,7 @@ test_suite(
102106
":test_remove_detach_pass",
103107
":test_remove_dropout_pass",
104108
":test_remove_unnecessary_casts",
109+
":test_rewrite_inputs_with_params",
105110
":test_unpack_hardsigmoid",
106111
":test_unpack_hardswish",
107112
":test_unpack_reduce_ops",
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/csrc/jit/ir/subgraph_matcher.h"
8+
9+
TEST(LoweringPasses, RewriteInputsWithParamsCorrectly) {
10+
std::string source_graph = R"IR(
11+
graph(%x: Tensor, %y: Tensor, %1 : Int(1)):
12+
%out: Tensor = aten::sub(%x, %y, %1)
13+
return (%out))IR";
14+
std::string target_graph = R"IR(
15+
graph(%x: Tensor, %y : Tensor):
16+
%2 : int = prim::Constant[value=0]()
17+
%out: Tensor = aten::sub(%x, %y, %2)
18+
return (%out))IR";
19+
20+
torch_tensorrt::core::util::logging::get_logger().set_reportable_log_level(
21+
torch_tensorrt::core::util::logging::LogLevel::kGRAPH);
22+
auto sg = std::make_shared<torch::jit::Graph>();
23+
torch::jit::parseIR(source_graph, &*sg);
24+
25+
torch::jit::IValue param0 = torch::jit::IValue(0);
26+
std::vector<torch::jit::IValue> params{param0};
27+
torch_tensorrt::core::lowering::passes::RewriteInputsWithParams(sg, params);
28+
29+
auto tg = std::make_shared<torch::jit::Graph>();
30+
torch::jit::parseIR(target_graph, &*tg);
31+
32+
ASSERT_TRUE(!torch::jit::findPatternMatches(*tg, *sg).empty());
33+
}

0 commit comments

Comments
 (0)