
Commit 57681f7

Iteratively run the main simplification pipeline.

This introduces a new pass LowerToBackendContract (better name very welcome)
which performs the bulk of the simplifications that we do, such as:

- shape refinement
- dtype refinement
- maximizing value semantics
- inlining global slots
- decomposing complex ops

The key difference from before is that it iterates the set of transformations,
which helps to break a number of "catch-22" issues where one simplification
depends on another, the latest example being #1131.

This also exposed that RefineTypes was sometimes crashing/asserting for
certain inputs. This commit hardens it a bit.
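The iteration structure is roughly as follows. This is a minimal sketch, not
the literal contents of LowerToBackendContract.cpp: satisfiesBackendContract
is a hypothetical helper standing in for whatever check the pass actually
performs, and the tablegen-generated pass scaffolding is assumed.

#include "mlir/Pass/Pass.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

// Sketch only: assumes the generated LowerToBackendContractBase scaffolding
// and the pipeline declared in Passes.h.
void LowerToBackendContractPass::runOnOperation() {
  ModuleOp module = getOperation();
  // Build the simplification pipeline once; re-run it until a fixed point.
  // (Forwarding the pass's `decompose` option into `options` is elided.)
  OpPassManager pm(ModuleOp::getOperationName());
  TorchLoweringPipelineOptions options;
  createTorchSimplificationPipeline(pm, options);

  int i = 0;
  do {
    if (i++ == maxIterations) {
      // The "catch-22" chains did not resolve within the iteration budget.
      module.emitError() << "failed to satisfy the backend contract after "
                         << maxIterations << " iterations";
      return signalPassFailure();
    }
    if (failed(runPipeline(pm, module)))
      return signalPassFailure();
  } while (!satisfiesBackendContract(module)); // hypothetical check
}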

14 files changed, +518 −251 lines changed

include/torch-mlir/Dialect/Torch/Transforms/Passes.h

Lines changed: 23 additions & 11 deletions
@@ -28,15 +28,21 @@ createPrepareForGlobalizeObjectGraphPass();
 
 struct TorchLoweringPipelineOptions
     : public PassPipelineOptions<TorchLoweringPipelineOptions> {
-  // If this option is true, then perform optimizations.
-  // If this option is false, only do the bare minimum for correctness.
-  Option<bool> optimize{*this, "optimize", llvm::cl::desc("Do optimizations."),
-                        llvm::cl::init(true)};
-
+  // The maximum number of invocations of the simplification pipeline in
+  // LowerToBackendContract.
+  Option<int> maxIterations{
+      *this, "max-iterations",
+      llvm::cl::desc(
+          "Maximum number of invocations of the simplification pipeline."),
+      llvm::cl::init(10)};
   // If this option is true, decompose complex operations.
   // If this option is false, skip decomposition of complex operations.
-  Option<bool> decompose{*this, "decompose-complex-ops", llvm::cl::desc("Decompose complex operations."),
-                         llvm::cl::init(true)};
+  // TODO: This should be replaced with a list of operations to decompose
+  // (or some other way to specify the set of allowed ops in the backend
+  // contract).
+  Option<bool> decompose{*this, "decompose-complex-ops",
+                         llvm::cl::desc("Decompose complex operations."),
+                         llvm::cl::init(true)};
 };
 
 /// Creates a pipeline that lowers the object graph IR that is produced by
@@ -50,10 +56,16 @@ void createTorchScriptModuleToTorchBackendPipeline(
 void createTorchFunctionToTorchBackendPipeline(
     OpPassManager &pm, const TorchLoweringPipelineOptions &options);
 
-/// Creates a pipeline that refines shapes of tensor operations in the program.
-void createTorchShapeRefinementPipeline(
+/// Creates a pipeline that simplifies the computations in the program.
+/// This pipeline does not do any global program restructuring -- it works
+/// entirely within a single semantic model of a `builtin.module` with
+/// `torch.global_slot` ops and `func.func` ops.
+void createTorchSimplificationPipeline(
     OpPassManager &pm, const TorchLoweringPipelineOptions &options);
 
+/// Creates a pipeline that refines shapes of tensor operations in the program.
+void createTorchShapeRefinementPipeline(OpPassManager &pm);
+
 std::unique_ptr<OperationPass<ModuleOp>> createAdjustCallingConventionsPass();
 
 std::unique_ptr<OperationPass<func::FuncOp>> createRefineTypesPass();
@@ -78,10 +90,10 @@ createSimplifyShapeCalculationsPass();
 std::unique_ptr<OperationPass<func::FuncOp>> createDropShapeCalculationsPass();
 
 std::unique_ptr<OperationPass<ModuleOp>>
-createVerifyConversionToValueSemanticsPass();
+createEraseModuleInitializerPass();
 
 std::unique_ptr<OperationPass<ModuleOp>>
-createEraseModuleInitializerPass();
+createLowerToBackendContractPass(int maxIterations, bool decompose);
 
 StringRef getShapeLibrary();
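A hedged usage sketch of the new declaration, driving the pass from a
standalone PassManager. Where the pass actually slots into torch-mlir's
pipelines is not shown in this hunk, so the placement here is illustrative
only.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/PassManager.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"

// Runs only the new pass on a module, using the signature declared above.
static mlir::LogicalResult lowerToBackendContract(mlir::ModuleOp module) {
  mlir::PassManager pm(module.getContext());
  pm.addPass(mlir::torch::Torch::createLowerToBackendContractPass(
      /*maxIterations=*/10, /*decompose=*/true));
  return pm.run(module);
}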

include/torch-mlir/Dialect/Torch/Transforms/Passes.td

Lines changed: 57 additions & 14 deletions
@@ -253,18 +253,6 @@ def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "func::FuncOp"
   }];
 }
 
-def VerifyConversionToValueSemantics
-    : Pass<"torch-verify-conversion-to-value-semantics", "ModuleOp"> {
-  let summary = "Verify that all tensors have been converted to value semantics";
-  let constructor =
-      "mlir::torch::Torch::createVerifyConversionToValueSemanticsPass()";
-  let description = [{
-    Prior passes in the pipeline may have missed converting all tensors to value
-    semantics, and we wish to catch such failures early instead of fixing
-    individual cases downstream.
-  }];
-}
-
 def EraseModuleInitializer
     : Pass<"torch-erase-module-initializer", "ModuleOp"> {
   let summary = "Erase the `torch.global_slot.module_initializer` op.";
@@ -273,9 +261,64 @@ def EraseModuleInitializer
   let description = [{
     Backends cannot currently handle module initializers, so we omit them from
     our backend contract. This pass removes the
-    `torch.global_slot.module_initializer` op from the module if legal, or
-    raises an error.
+    `torch.global_slot.module_initializer` op from the module if legal.
+  }];
+}
+
+def LowerToBackendContract
+    : Pass<"torch-lower-to-backend-contract", "ModuleOp"> {
+  let summary = "Perform simplifications until the backend contract is satisfied.";
+  let constructor = [{
+    mlir::torch::Torch::createLowerToBackendContractPass(
+        /*maxIterations=*/10, /*decompose=*/true)
+  }];
+  let description = [{
+    This pass performs the bulk of the lowering of the program's computations
+    to the backend contract. This pass does not do any global program
+    restructuring -- it works entirely within a single semantic model
+    of a `builtin.module` with `torch.global_slot` ops and `func.func` ops.
+
+    This pass runs a set of simplifications within that semantic model until
+    the backend contract is satisfied, and fails if it cannot be satisfied.
+    In particular, the backend contract consists of:
+    - Tensors
+      - Have been converted to value semantics.
+      - Have at least a known rank, though ideally a maximally inferred shape.
+      - Have a known dtype.
+    - `torch.global_slot` ops have been eliminated from the program.
+    - Ops have been decomposed.
+
+    This particular choice of backend contract was born out of a common set of
+    requirements from backends, along with aligning with the long-term PyTorch
+    direction of being more tracing-based. The set of simplifications performed
+    here can be thought of as simulating the kinds of simplifications that
+    happen naturally as part of tracing, but in a way that is applicable
+    to our TorchScript frontend. For the LazyTensorCore frontend, the backend
+    contract trivially holds (except for certain decompositions).
+
+    Generally it is not desirable to have a compiler where successful
+    compilation depends on "optimizing hard enough", but in this case there
+    seems to be enough alignment and recognition in the industry that the
+    Python-based programming model in the source program is too dynamic
+    to feasibly handle in totality without a tracing approach that has access
+    to the source program to re-trace in the face of dynamism (e.g. the ability
+    to do what TorchDynamo calls a "graph break"). We are attempting to maintain
+    a practical compiler that works well given the current set of constraints
+    of the TorchScript frontend that PyTorch provides us, and are working to
+    co-design PyTorch's direction so that we land in a place where most of this
+    "optimizing hard enough" is not necessary.
   }];
+  let options = [
+    Option<"maxIterations", "max-iterations", "int", /*default=*/"10",
+           "Maximum number of invocations of the simplification pipeline.">,
+    // TODO: Make this a configurable set of ops.
+    Option<"decompose", "decompose", "bool", /*default=*/"true",
+           "Decompose ops.">
+  ];
+  // TODO: Debug why this is needed, even though the input program has
+  // func.func ops in it.
+  let dependentDialects = ["func::FuncDialect"];
 }
 
 #endif // TORCHMLIR_TORCH_PASSES
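The contract bullets in the description translate fairly directly into a
check over the module. Below is a minimal sketch of such a check, assuming
torch-mlir's NonValueTensorType/ValueTensorType APIs (`hasSizes()`,
`hasDtype()`); the pass's real check is not shown in this hunk and may differ
(for example, it must also account for decomposition).

#include "mlir/IR/BuiltinOps.h"
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"

// Sketch: returns true if the module looks like it satisfies the backend
// contract described above (value semantics, known rank and dtype, no
// remaining global slots). Illustrative only.
static bool satisfiesBackendContract(mlir::ModuleOp module) {
  using namespace mlir::torch::Torch;
  // `torch.global_slot` ops must have been eliminated.
  if (!module.getOps<GlobalSlotOp>().empty())
    return false;
  bool ok = true;
  module.walk([&](mlir::Operation *op) {
    for (mlir::Type type : op->getResultTypes()) {
      // A non-value tensor means value semantics was not achieved.
      if (type.isa<NonValueTensorType>())
        ok = false;
      // Value tensors must have at least a known rank and a known dtype.
      else if (auto tensorType = type.dyn_cast<ValueTensorType>())
        if (!tensorType.hasSizes() || !tensorType.hasDtype())
          ok = false;
    }
  });
  return ok;
}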

lib/Dialect/Torch/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
@@ -6,6 +6,7 @@ add_mlir_library(TorchMLIRTorchPasses
   Passes.cpp
   GlobalizeObjectGraph.cpp
   InlineGlobalSlots.cpp
+  LowerToBackendContract.cpp
   MaximizeValueSemantics.cpp
   PrepareForGlobalizeObjectGraph.cpp
   ReduceOpVariants.cpp
@@ -14,7 +15,6 @@ add_mlir_library(TorchMLIRTorchPasses
   ReifyShapeCalculations.cpp
   ShapeLibrary.cpp
   SimplifyShapeCalculations.cpp
-  VerifyConversionToValueSemantics.cpp
 
   ADDITIONAL_HEADER_DIRS
   ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/Torch/Transforms

lib/Dialect/Torch/Transforms/EraseModuleInitializer.cpp

Lines changed: 7 additions & 10 deletions
@@ -27,18 +27,15 @@ namespace {
 class EraseModuleInitializerPass
     : public EraseModuleInitializerBase<EraseModuleInitializerPass> {
   void runOnOperation() override {
-    auto walkResult = getOperation().walk([](GlobalSlotModuleInitializerOp op) {
+    for (auto initializer :
+         getOperation().getOps<GlobalSlotModuleInitializerOp>()) {
       auto intialize =
-          cast<InitializeGlobalSlotsOp>(op.getBody()->getTerminator());
-      if (intialize.getNumOperands() != 0) {
-        op.emitError("could not erase non-empty module initializer");
-        return WalkResult::interrupt();
+          cast<InitializeGlobalSlotsOp>(initializer.getBody()->getTerminator());
+      if (intialize.getNumOperands() == 0) {
+        initializer.erase();
       }
-      op.erase();
-      return WalkResult::advance();
-    });
-    if (walkResult.wasInterrupted()) {
-      return signalPassFailure();
+      // The verifier ensures there is only one GlobalSlotModuleInitializerOp.
+      break;
     }
   }
 };
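Assembled from the `+` and context lines, the patched method reads as follows
for reference; the comment about a later pipeline iteration is an inference
from the commit message, not text from the source.

class EraseModuleInitializerPass
    : public EraseModuleInitializerBase<EraseModuleInitializerPass> {
  void runOnOperation() override {
    for (auto initializer :
         getOperation().getOps<GlobalSlotModuleInitializerOp>()) {
      auto intialize =
          cast<InitializeGlobalSlotsOp>(initializer.getBody()->getTerminator());
      // Only an empty initializer is erased; a non-empty one is now left in
      // place (presumably for a later iteration of the simplification
      // pipeline to empty out) instead of being a hard error.
      if (intialize.getNumOperands() == 0) {
        initializer.erase();
      }
      // The verifier ensures there is only one GlobalSlotModuleInitializerOp.
      break;
    }
  }
};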
