@@ -253,18 +253,6 @@ def DropShapeCalculations : Pass<"torch-drop-shape-calculations", "func::FuncOp"
253
253
}];
254
254
}
255
255
256
- def VerifyConversionToValueSemantics
257
- : Pass<"torch-verify-conversion-to-value-semantics", "ModuleOp"> {
258
- let summary = "Verify that all tensors have been converted to value semantics";
259
- let constructor =
260
- "mlir::torch::Torch::createVerifyConversionToValueSemanticsPass()";
261
- let description = [{
262
- Prior passes in the pipeline may have missed converting all tensors to value
263
- semantics and we wish to catch such failures early instead of fixing
264
- individual cases downstream.
265
- }];
266
- }
267
-
268
256
def EraseModuleInitializer
269
257
: Pass<"torch-erase-module-initializer", "ModuleOp"> {
270
258
let summary = "Erase the `torch.global_slot.module_initializer` op.";
@@ -273,9 +261,64 @@ def EraseModuleInitializer
273
261
let description = [{
274
262
Backends cannot currently handle module initializers, so we omit them from
275
263
our backend contract. This pass removes the
276
- `torch.global_slot.module_initializer` op from the module if legal, or
277
- raises an error.
264
+ `torch.global_slot.module_initializer` op from the module if legal.
265
+ }];
266
+ }
267
+
268
// Aggregate simplification pass: iterates the simplification pipeline (up to
// `maxIterations` times) until the backend contract holds, or fails.
def LowerToBackendContract
    : Pass<"torch-lower-to-backend-contract", "ModuleOp"> {
  let summary = "Perform simplifications until the backend contract is satisfied.";
  let constructor = [{
    mlir::torch::Torch::createLowerToBackendContractPass(
      /*maxIterations=*/10, /*decompose=*/true)
  }];
  let description = [{
    This pass performs the bulk of the lowering of the program's computations
    to the backend contract. This pass does not do any global program
    restructuring -- it works entirely within a single semantic model
    of a `builtin.module` with `torch.global_slot` ops and `func.func` ops.

    This pass runs a set of simplifications within that semantic model until
    the backend contract is satisfied, and fails if it cannot be satisfied.
    In particular, the backend contract consists of:
    - Tensors
      - Have been converted to value semantics.
      - Have at least a known rank, though ideally a maximally inferred shape.
      - Have a known dtype.
    - `torch.global_slot`'s have been eliminated from the program.
    - Ops have been decomposed.

    This particular choice of backend contract was born out of a common set of
    requirements from backends, along with aligning with long-term PyTorch
    direction of being more tracing-based. The set of simplifications performed
    here can be thought of as simulating the kinds of simplifications that
    happen naturally as part of tracing, but in a way that is applicable
    to our TorchScript frontend. For the LazyTensorCore frontend, the backend
    contract trivially holds (except for certain decompositions).

    Generally it is not desirable to have a compiler where successful
    compilation depends on "optimizing hard enough", but in this case, there
    seems to be enough alignment and recognition in the industry that the
    Python-based programming model in the source program is too dynamic
    to feasibly handle in totality without a tracing approach that has access
    to the source program to re-trace in the face of dynamism (e.g. the ability
    to do what TorchDynamo calls "graph break"). We are attempting to maintain
    a practical compiler that works well given the current set of constraints
    of the TorchScript frontend that PyTorch provides us, and are working to
    co-design PyTorch's direction so that we land in a place where most of this
    "optimizing hard enough" is not necessary.
  }];
  let options = [
    Option<"maxIterations", "max-iterations", "int", /*default=*/"10",
           "Maximum number of invocations of the simplification pipeline.">,
    // TODO: Make this a configurable set of ops.
    Option<"decompose", "decompose", "bool", /*default=*/"true",
           "Decompose ops.">
  ];
  // TODO: Debug why this is needed, even though the input program has func.func
  // ops in it.
  let dependentDialects = ["func::FuncDialect"];
}
280
323
281
324
#endif // TORCHMLIR_TORCH_PASSES
0 commit comments