Skip to content

Commit 1adbd37

Browse files
author
Vremold
committed
Execute ShapeRefinementPipeline and canonicalizer multiple times to reason out more static shapes
1 parent c38308f commit 1adbd37

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ class LowerToBackendContractPass
212212

213213
OpPassManager pm(module.getOperationName());
214214
TorchLoweringPipelineOptions options;
215+
options.maxIterations = maxIterations;
215216
options.decompose = decompose;
216217
options.backendLegalOps = backendLegalOps;
217218
createTorchSimplificationPipeline(pm, options);

lib/Dialect/Torch/Transforms/Passes.cpp

+8-6
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,14 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
124124
pm.addNestedPass<func::FuncOp>(Torch::createMaximizeValueSemanticsPass());
125125
// Update the return op to return value tensors.
126126
pm.addPass(Torch::createRefinePublicReturnPass());
127-
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
128-
// Do shape refinement.
129-
// This should be run before RefineTypes (which primarily does dtype
130-
// inference), because Torch type promotion rules actually depend on the shape
131-
// of the operand.
132-
createTorchShapeRefinementPipeline(pm);
127+
for (int i = 0; i < options.maxIterations; ++i) {
128+
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
129+
// Do shape refinement.
130+
// This should be run before RefineTypes (which primarily does dtype
131+
// inference), because Torch type promotion rules actually depend on the shape
132+
// of the operand.
133+
createTorchShapeRefinementPipeline(pm);
134+
}
133135
// Refine types in the program, which mainly means inferring dtypes of ops.
134136
pm.addNestedPass<func::FuncOp>(Torch::createRefineTypesPass());
135137
// Propagate to ABI return types the shape/dtype information discovered by

0 commit comments

Comments
 (0)