Skip to content

Commit 4f30ece

Browse files
author
Vremold
committed
Iteratively run SimplificationPipeline until code optimization converges
1 parent d7d6797 commit 4f30ece

File tree

1 file changed

+45
-1
lines changed

1 file changed

+45
-1
lines changed

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

+45-1
Original file line numberDiff line numberDiff line change
@@ -197,6 +197,40 @@ static bool satisfiesBackendContract(ModuleOp module,
197197
return true;
198198
}
199199

200+
// Forward declaration
201+
static llvm::hash_code hashOperation(Operation *);
202+
static llvm::hash_code hashBlock(Block &block) {
203+
llvm::hash_code hash(0);
204+
for (Operation &op : block.getOperations()) {
205+
llvm::hash_code opHash = hashOperation(&op);
206+
hash = llvm::hash_combine(hash, opHash);
207+
}
208+
return hash;
209+
}
210+
211+
static llvm::hash_code hashRegion(Region &region) {
212+
llvm::hash_code hash(0);
213+
for (Block &block : region.getBlocks()) {
214+
llvm::hash_code blockHash = hashBlock(block);
215+
hash = llvm::hash_combine(hash, blockHash);
216+
}
217+
return hash;
218+
}
219+
220+
static llvm::hash_code hashOperation(Operation *op) {
221+
llvm::hash_code hash(0);
222+
llvm::hash_code opHash = OperationEquivalence::computeHash(
223+
op, OperationEquivalence::ignoreHashValue,
224+
OperationEquivalence::ignoreHashValue,
225+
OperationEquivalence::IgnoreLocations);
226+
hash = llvm::hash_combine(hash, opHash);
227+
for (auto &region : op->getRegions()) {
228+
llvm::hash_code regionHash = hashRegion(region);
229+
hash = llvm::hash_combine(hash, regionHash);
230+
}
231+
return hash;
232+
}
233+
200234
namespace {
201235
class LowerToBackendContractPass
202236
: public LowerToBackendContractBase<LowerToBackendContractPass> {
@@ -217,6 +251,9 @@ class LowerToBackendContractPass
217251
options.backendLegalOps = backendLegalOps;
218252
createTorchSimplificationPipeline(pm, options);
219253

254+
bool codeChanged = false;
255+
llvm::hash_code moduleHash = hashOperation(module);
256+
220257
int i = 0;
221258
do {
222259
if (i++ == maxIterations) {
@@ -234,7 +271,14 @@ class LowerToBackendContractPass
234271

235272
if (failed(runPipeline(pm, module)))
236273
return signalPassFailure();
237-
} while (!satisfiesBackendContract(module));
274+
275+
llvm::hash_code newModuleHash = hashOperation(module);
276+
codeChanged = (moduleHash != newModuleHash);
277+
moduleHash = newModuleHash;
278+
279+
// Iterate until maxIterations is reached or
280+
// backend contract is satisified and code optimization converges.
281+
} while (!satisfiesBackendContract(module) || codeChanged);
238282
LLVM_DEBUG({
239283
llvm::dbgs() << "LowerToBackendContractPass: "
240284
<< "succeeded after " << i

0 commit comments

Comments
 (0)