@@ -108,6 +108,57 @@ void EliminateExceptionOrPassPattern(std::shared_ptr<Graph> graph) {
108
108
}
109
109
}
110
110
111
+ /*
112
+ Below is a fork of the torch::jit::EliminateExceptions pass, with node replacement
113
+ using replaceAllUsesDominatedByNodeWith instead of replaceAllUsesWith,
114
+ so as to not invalidate the IR in challenging cases, such as nested Ifs
115
+
116
+ Original Source from which it was adapted:
117
+ https://github.com/pytorch/pytorch/blob/c29ab84115f40614d04e4557ea2e1ac40b7aa75c/torch/csrc/jit/passes/remove_exceptions.cpp
118
+ */
119
+
120
+ bool certainlyThrows (Block* block) {
121
+ // A block certainly throws an exception if it contains
122
+ // the prim::RaiseException operation
123
+ for (Node* n : block->nodes ()) {
124
+ if (n->kind () == prim::RaiseException) {
125
+ return true ;
126
+ }
127
+ }
128
+ return false ;
129
+ }
130
+
131
+ void EliminateExceptionsNew (Block* block) {
132
+ auto graph = block->owningGraph ();
133
+ // Generate false and true constant placeholders
134
+ Value* false_const = graph->insertConstant (IValue (false ));
135
+ Value* true_const = graph->insertConstant (IValue (true ));
136
+
137
+ // For each prim::If node, if either block certainly throws an exception
138
+ // Replace all uses of the node input with the logical opposite
139
+ for (Node* n : block->nodes ()) {
140
+ if (n->kind () == prim::If) {
141
+ Block* true_block = n->blocks ()[0 ];
142
+ Block* false_block = n->blocks ()[1 ];
143
+
144
+ if (certainlyThrows (true_block)) {
145
+ n->input (0 )->replaceAllUsesDominatedByNodeWith (n, false_const);
146
+ } else if (certainlyThrows (false_block)) {
147
+ n->input (0 )->replaceAllUsesDominatedByNodeWith (n, true_const);
148
+ }
149
+ }
150
+
151
+ // Inspect and replace all instances within subblocks of the current node
152
+ for (Block* subblock : n->blocks ()) {
153
+ EliminateExceptionsNew (subblock);
154
+ }
155
+ }
156
+ }
157
+
158
+ void EliminateExceptionsNew (std::shared_ptr<Graph>& graph) {
159
+ EliminateExceptionsNew (graph->block ());
160
+ }
161
+
111
162
} // namespace passes
112
163
} // namespace lowering
113
164
} // namespace core
0 commit comments