1
+ #include " core/lowering/passes/passes.h"
2
+ #include " gtest/gtest.h"
3
+ #include " torch/csrc/jit/ir/irparser.h"
4
+
5
+ TEST (LoweringPasses, EliminateExceptionOrPassPattern_Block0) {
6
+ // parseIR does not support " = prim::If(%51)" with no return value
7
+ /* std::string source_ir = R"IR(graph(%x.1 : Tensor, %y.1 : Tensor):
8
+ %3 : NoneType = prim::Constant()
9
+ %4 : int = prim::Constant[value=0]()
10
+ %mod_list.1 : Tensor[] = prim::ListConstruct(%x.1)
11
+ %47 : Tensor = aten::sum(%x.1, %3)
12
+ %49 : Tensor = aten::sum(%y.1, %3)
13
+ %50 : Tensor = aten::gt(%47, %49)
14
+ %51 : bool = aten::Bool(%50)
15
+ = prim::If(%51)
16
+ block0():
17
+ = prim::RaiseException(%45)
18
+ -> ()
19
+ block1():
20
+ -> ()
21
+ %z.1 : Tensor = aten::cat(%mod_list.1, %4)
22
+ return (%z.1))IR";*/
23
+
24
+ auto g = std::make_shared<torch::jit::Graph>();
25
+ auto x = g->insertInput (0 , " x" );
26
+ auto y = g->insertInput (1 , " y" );
27
+ torch::jit::IValue zero (0 );
28
+ auto zero_const_val = g->insertConstant (zero);
29
+ auto none_const_val = g->insertConstant (torch::jit::IValue ());
30
+ torch::jit::IValue except (" EXCEPTION" );
31
+ auto except_val = g->insertConstant (except);
32
+ auto list_node = g->createList (x->type (), torch::jit::ArrayRef<torch::jit::Value*>(x));
33
+ g->insertNode (list_node);
34
+ auto sum_x_node = g->create (torch::jit::aten::sum, {x, none_const_val});
35
+ g->insertNode (sum_x_node);
36
+ auto sum_y_node = g->create (torch::jit::aten::sum, {y, none_const_val});
37
+ g->insertNode (sum_y_node);
38
+ auto gt_node = g->create (torch::jit::aten::gt, {sum_x_node->output (), sum_y_node->output ()});
39
+ g->insertNode (gt_node);
40
+ auto bool_node = g->create (torch::jit::aten::Bool, {gt_node->output ()});
41
+ bool_node->output ()->setType (torch::jit::BoolType::get ());
42
+ g->insertNode (bool_node);
43
+ auto if_node = g->create (torch::jit::prim::If, {bool_node->output ()}, 0 );
44
+ auto if_block0 = if_node->addBlock ();
45
+ auto exception_node = g->create (torch::jit::prim::RaiseException, {except_val}, 0 );
46
+ if_block0->appendNode (exception_node);
47
+ auto if_block1 = if_node->addBlock ();
48
+ g->insertNode (if_node);
49
+ auto cat_node = g->create (torch::jit::aten::cat, {list_node->output (), zero_const_val});
50
+ g->insertNode (cat_node);
51
+ g->registerOutput (cat_node->output ());
52
+
53
+ torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern (g);
54
+ for (auto node : g->nodes ()) {
55
+ EXPECT_NE (node, if_node);
56
+ }
57
+ }
58
+
59
+ TEST (LoweringPasses, EliminateExceptionOrPassPattern_Block1) {
60
+ // parseIR does not support " = prim::If(%51)" with no return value
61
+ /* std::string source_ir = R"IR(graph(%x.1 : Tensor, %y.1 : Tensor):
62
+ %3 : NoneType = prim::Constant()
63
+ %4 : int = prim::Constant[value=0]()
64
+ %mod_list.1 : Tensor[] = prim::ListConstruct(%x.1)
65
+ %47 : Tensor = aten::sum(%x.1, %3)
66
+ %49 : Tensor = aten::sum(%y.1, %3)
67
+ %50 : Tensor = aten::gt(%47, %49)
68
+ %51 : bool = aten::Bool(%50)
69
+ = prim::If(%51)
70
+ block0():
71
+ -> ()
72
+ block1():
73
+ = prim::RaiseException(%45)
74
+ -> ()
75
+ %z.1 : Tensor = aten::cat(%mod_list.1, %4)
76
+ return (%z.1))IR";*/
77
+
78
+ auto g = std::make_shared<torch::jit::Graph>();
79
+ auto x = g->insertInput (0 , " x" );
80
+ auto y = g->insertInput (1 , " y" );
81
+ torch::jit::IValue zero (0 );
82
+ auto zero_const_val = g->insertConstant (zero);
83
+ auto none_const_val = g->insertConstant (torch::jit::IValue ());
84
+ torch::jit::IValue except (" EXCEPTION" );
85
+ auto except_val = g->insertConstant (except);
86
+ auto list_node = g->createList (x->type (), torch::jit::ArrayRef<torch::jit::Value*>(x));
87
+ g->insertNode (list_node);
88
+ auto sum_x_node = g->create (torch::jit::aten::sum, {x, none_const_val});
89
+ g->insertNode (sum_x_node);
90
+ auto sum_y_node = g->create (torch::jit::aten::sum, {y, none_const_val});
91
+ g->insertNode (sum_y_node);
92
+ auto gt_node = g->create (torch::jit::aten::gt, {sum_x_node->output (), sum_y_node->output ()});
93
+ g->insertNode (gt_node);
94
+ auto bool_node = g->create (torch::jit::aten::Bool, {gt_node->output ()});
95
+ bool_node->output ()->setType (torch::jit::BoolType::get ());
96
+ g->insertNode (bool_node);
97
+ auto if_node = g->create (torch::jit::prim::If, {bool_node->output ()}, 0 );
98
+ auto if_block0 = if_node->addBlock ();
99
+ auto if_block1 = if_node->addBlock ();
100
+ auto exception_node = g->create (torch::jit::prim::RaiseException, {except_val}, 0 );
101
+ if_block1->appendNode (exception_node);
102
+ g->insertNode (if_node);
103
+ auto cat_node = g->create (torch::jit::aten::cat, {list_node->output (), zero_const_val});
104
+ g->insertNode (cat_node);
105
+ g->registerOutput (cat_node->output ());
106
+
107
+ torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern (g);
108
+ for (auto node : g->nodes ()) {
109
+ EXPECT_NE (node, if_node);
110
+ }
111
+ }
112
+
113
+ TEST (LoweringPasses, EliminateExceptionOrPassPattern_Negative) {
114
+ // parseIR does not support " = prim::If(%51)" with no return value
115
+ /* std::string source_ir = R"IR(graph(%x.1 : Tensor, %y.1 : Tensor):
116
+ %3 : NoneType = prim::Constant()
117
+ %4 : int = prim::Constant[value=0]()
118
+ %mod_list.1 : Tensor[] = prim::ListConstruct(%x.1)
119
+ %47 : Tensor = aten::sum(%x.1, %3)
120
+ %49 : Tensor = aten::sum(%y.1, %3)
121
+ %50 : Tensor = aten::gt(%47, %49)
122
+ %51 : bool = aten::Bool(%50)
123
+ = prim::If(%51)
124
+ block0():
125
+ %10 : Tensor[] = aten::append(%mod_list.1, %y.1)
126
+ -> ()
127
+ block1():
128
+ -> ()
129
+ %z.1 : Tensor = aten::cat(%mod_list.1, %4)
130
+ return (%z.1))IR";*/
131
+
132
+ auto g = std::make_shared<torch::jit::Graph>();
133
+ auto x = g->insertInput (0 , " x" );
134
+ auto y = g->insertInput (1 , " y" );
135
+ torch::jit::IValue zero (0 );
136
+ auto zero_const_val = g->insertConstant (zero);
137
+ auto none_const_val = g->insertConstant (torch::jit::IValue ());
138
+ auto list_node = g->createList (x->type (), torch::jit::ArrayRef<torch::jit::Value*>(x));
139
+ g->insertNode (list_node);
140
+ auto sum_x_node = g->create (torch::jit::aten::sum, {x, none_const_val});
141
+ g->insertNode (sum_x_node);
142
+ auto sum_y_node = g->create (torch::jit::aten::sum, {y, none_const_val});
143
+ g->insertNode (sum_y_node);
144
+ auto gt_node = g->create (torch::jit::aten::gt, {sum_x_node->output (), sum_y_node->output ()});
145
+ g->insertNode (gt_node);
146
+ auto bool_node = g->create (torch::jit::aten::Bool, {gt_node->output ()});
147
+ bool_node->output ()->setType (torch::jit::BoolType::get ());
148
+ g->insertNode (bool_node);
149
+ auto if_node = g->create (torch::jit::prim::If, {bool_node->output ()}, 0 );
150
+ auto if_block0 = if_node->addBlock ();
151
+ auto append_node = g->create (torch::jit::aten::append, {list_node->output (), y});
152
+ if_block0->appendNode (append_node);
153
+ auto if_block1 = if_node->addBlock ();
154
+ g->insertNode (if_node);
155
+ auto cat_node = g->create (torch::jit::aten::cat, {list_node->output (), zero_const_val});
156
+ g->insertNode (cat_node);
157
+ g->registerOutput (cat_node->output ());
158
+
159
+ torch_tensorrt::core::lowering::passes::EliminateExceptionOrPassPattern (g);
160
+ int if_count = 0 ;
161
+ for (auto node : g->nodes ()) {
162
+ if (node == if_node) {
163
+ if_count++;
164
+ }
165
+ }
166
+ EXPECT_EQ (1 , if_count);
167
+ }
0 commit comments