Skip to content

Commit 972a717

Browse files
committed
fix: Repair EliminateExceptions lowering pass
- Update EliminateExceptions to use `replaceAllUsesDominatedByNodeWith` instead of `replaceAllUsesWith` to avoid issue with invalid IR causing program halting
1 parent c16bc3a commit 972a717

File tree

3 files changed

+53
-0
lines changed

3 files changed

+53
-0
lines changed

core/lowering/lowering.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::I
104104
torch::jit::InlineFunctionalGraphs(g);
105105
torch::jit::PeepholeOptimize(g, false);
106106
torch::jit::FuseLinear(g);
107+
passes::EliminateExceptionsNew(g);
107108
if (!lower_info.disable_cse) {
108109
torch::jit::EliminateCommonSubexpression(g);
109110
}

core/lowering/passes/exception_elimination.cpp

+51
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,57 @@ void EliminateExceptionOrPassPattern(std::shared_ptr<Graph> graph) {
108108
}
109109
}
110110

111+
/*
112+
Below is a fork of the torch::jit::EliminateExceptions pass, with node replacement
113+
using replaceAllUsesDominatedByNodeWith instead of replaceAllUsesWith,
114+
so as to not invalidate the IR in challenging cases, such as nested Ifs
115+
116+
Original Source from which it was adapted:
117+
https://github.com/pytorch/pytorch/blob/c29ab84115f40614d04e4557ea2e1ac40b7aa75c/torch/csrc/jit/passes/remove_exceptions.cpp
118+
*/
119+
120+
bool certainlyThrows(Block* block) {
121+
// A block certainly throws an exception if it contains
122+
// the prim::RaiseException operation
123+
for (Node* n : block->nodes()) {
124+
if (n->kind() == prim::RaiseException) {
125+
return true;
126+
}
127+
}
128+
return false;
129+
}
130+
131+
void EliminateExceptionsNew(Block* block) {
132+
auto graph = block->owningGraph();
133+
// Generate false and true constant placeholders
134+
Value* false_const = graph->insertConstant(IValue(false));
135+
Value* true_const = graph->insertConstant(IValue(true));
136+
137+
// For each prim::If node, if either block certainly throws an exception
138+
// Replace all uses of the node input with the logical opposite
139+
for (Node* n : block->nodes()) {
140+
if (n->kind() == prim::If) {
141+
Block* true_block = n->blocks()[0];
142+
Block* false_block = n->blocks()[1];
143+
144+
if (certainlyThrows(true_block)) {
145+
n->input(0)->replaceAllUsesDominatedByNodeWith(n, false_const);
146+
} else if (certainlyThrows(false_block)) {
147+
n->input(0)->replaceAllUsesDominatedByNodeWith(n, true_const);
148+
}
149+
}
150+
151+
// Inspect and replace all instances within subblocks of the current node
152+
for (Block* subblock : n->blocks()) {
153+
EliminateExceptionsNew(subblock);
154+
}
155+
}
156+
}
157+
158+
void EliminateExceptionsNew(std::shared_ptr<Graph>& graph) {
159+
EliminateExceptionsNew(graph->block());
160+
}
161+
111162
} // namespace passes
112163
} // namespace lowering
113164
} // namespace core

core/lowering/passes/passes.h

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
2020
void ConvTransposed3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
2121
void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);
2222
void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph);
23+
void EliminateExceptionsNew(std::shared_ptr<torch::jit::Graph>& graph);
2324
void EliminateExceptionOrPassPattern(std::shared_ptr<torch::jit::Graph> graph);
2425
void ReduceToOperation(std::shared_ptr<torch::jit::Graph>& graph);
2526
void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph);

0 commit comments

Comments
 (0)