File tree 2 files changed +9
-6
lines changed
lib/Dialect/Torch/Transforms
2 files changed +9
-6
lines changed Original file line number Diff line number Diff line change @@ -213,6 +213,7 @@ class LowerToBackendContractPass
213
213
214
214
OpPassManager pm (module.getOperationName ());
215
215
TorchLoweringPipelineOptions options;
216
+ options.maxIterations = maxIterations;
216
217
options.decompose = decompose;
217
218
options.backendLegalOps = backendLegalOps;
218
219
createTorchSimplificationPipeline (pm, options);
Original file line number Diff line number Diff line change @@ -124,12 +124,14 @@ void mlir::torch::Torch::createTorchSimplificationPipeline(
124
124
pm.addNestedPass <func::FuncOp>(Torch::createMaximizeValueSemanticsPass ());
125
125
// Update the return op to return value tensors.
126
126
pm.addPass (Torch::createRefinePublicReturnPass ());
127
- pm.addNestedPass <func::FuncOp>(createCanonicalizerPass ());
128
- // Do shape refinement.
129
- // This should be run before RefineTypes (which primarily does dtype
130
- // inference), because Torch type promotion rules actually depend on the shape
131
- // of the operand.
132
- createTorchShapeRefinementPipeline (pm);
127
+ for (int i = 0 ; i < options.maxIterations ; ++i) {
128
+ pm.addNestedPass <func::FuncOp>(createCanonicalizerPass ());
129
+ // Do shape refinement.
130
+ // This should be run before RefineTypes (which primarily does dtype
131
+ // inference), because Torch type promotion rules actually depend on the shape
132
+ // of the operand.
133
+ createTorchShapeRefinementPipeline (pm);
134
+ }
133
135
// Refine types in the program, which mainly means inferring dtypes of ops.
134
136
pm.addNestedPass <func::FuncOp>(Torch::createRefineTypesPass ());
135
137
// Propagate to ABI return types the shape/dtype information discovered by
You can’t perform that action at this time.
0 commit comments