Skip to content

Commit edbe3ac

Browse files
author
Vremold
committed
Execute ShapeRefinementPipeline and canonicalizer multiple times to reason out more static shapes
1 parent 8cad02f commit edbe3ac

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,7 @@ class LowerToBackendContractPass
213213

214214
OpPassManager pm(module.getOperationName());
215215
TorchLoweringPipelineOptions options;
216+
options.maxIterations = maxIterations;
216217
options.decompose = decompose;
217218
options.backendLegalOps = backendLegalOps;
218219
createTorchSimplificationPipeline(pm, options);

lib/Dialect/Torch/Transforms/Passes.cpp

+8-6
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,14 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
124124
pm.addNestedPass<func::FuncOp>(Torch::createMaximizeValueSemanticsPass());
125125
// Update the return op to return value tensors.
126126
pm.addPass(Torch::createRefinePublicReturnPass());
127-
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
128-
// Do shape refinement.
129-
// This should be run before RefineTypes (which primarily does dtype
130-
// inference), because Torch type promotion rules actually depend on the shape
131-
// of the operand.
132-
createTorchShapeRefinementPipeline(pm);
127+
for (int i = 0; i < options.maxIterations; ++i) {
128+
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
129+
// Do shape refinement.
130+
// This should be run before RefineTypes (which primarily does dtype
131+
// inference), because Torch type promotion rules actually depend on the shape
132+
// of the operand.
133+
createTorchShapeRefinementPipeline(pm);
134+
}
133135
// Refine types in the program, which mainly means inferring dtypes of ops.
134136
pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
135137
// Propagate to ABI return types the shape/dtype information discovered by

0 commit comments

Comments
 (0)