Skip to content

Commit abfac56

Browse files
authored
[mlir][mesh] Make sharding propagation and spmdization work on FuncOpInterface (#84415)
Make them more general instead of only supporting `func::FuncOp`.
1 parent 8160139 commit abfac56

File tree

6 files changed

+20
-16
lines changed

6 files changed

+20
-16
lines changed

mlir/include/mlir/Dialect/Mesh/Transforms/Passes.td

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ include "mlir/Pass/PassBase.td"
1616
// ShardingPropagation
1717
//===----------------------------------------------------------------------===//
1818

19-
def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
19+
def ShardingPropagation : InterfacePass<"sharding-propagation", "mlir::FunctionOpInterface"> {
2020
let summary = "sharding propagation";
2121
let description = [{
2222
Propagates sharding information throughout the graph. After this pass, each
@@ -29,7 +29,7 @@ def ShardingPropagation : Pass<"sharding-propagation", "mlir::func::FuncOp"> {
2929
];
3030
}
3131

32-
def Spmdization : Pass<"mesh-spmdization", "mlir::func::FuncOp"> {
32+
def Spmdization : InterfacePass<"mesh-spmdization", "mlir::FunctionOpInterface"> {
3333
let summary = "Partition a function into SPMD form.";
3434
let description = [{
3535
This pass fits in right after a pass that annotates the function with

mlir/lib/Dialect/Mesh/Transforms/ShardingPropagation.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
#include "mlir/Dialect/Mesh/IR/MeshDialect.h"
1313
#include "mlir/Dialect/Mesh/IR/MeshOps.h"
1414
#include "mlir/Dialect/Mesh/Interfaces/ShardingInterface.h"
15+
#include "mlir/Interfaces/FunctionInterfaces.h"
1516
#include "mlir/Pass/Pass.h"
1617
#include "llvm/Support/Debug.h"
1718
#include <vector>
@@ -172,9 +173,9 @@ static LogicalResult visitOp(Operation *op, OpBuilder &builder) {
172173
struct ShardingPropagation
173174
: public mesh::impl::ShardingPropagationBase<ShardingPropagation> {
174175
void runOnOperation() override {
175-
func::FuncOp funcOp = getOperation();
176+
FunctionOpInterface funcOp = getOperation();
176177
MLIRContext *ctx = funcOp.getContext();
177-
Region &region = funcOp.getBody();
178+
Region &region = funcOp.getFunctionBody();
178179
OpBuilder builder(ctx);
179180
if (!region.hasOneBlock()) {
180181
funcOp.emitOpError() << "only one block is supported!";

mlir/lib/Dialect/Mesh/Transforms/Spmdization.cpp

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
#include "mlir/IR/MLIRContext.h"
2525
#include "mlir/IR/SymbolTable.h"
2626
#include "mlir/IR/Value.h"
27+
#include "mlir/Interfaces/ControlFlowInterfaces.h"
28+
#include "mlir/Interfaces/FunctionInterfaces.h"
2729
#include "mlir/Pass/Pass.h"
2830
#include "mlir/Support/LLVM.h"
2931
#include "mlir/Support/LogicalResult.h"
@@ -694,7 +696,7 @@ static LogicalResult spmdizeBlock(Block &block, IRMapping &spmdizationMap,
694696
}
695697

696698
static LogicalResult
697-
spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
699+
spmdizeFuncOp(FunctionOpInterface op, IRMapping &spmdizationMap,
698700
SymbolTableCollection &symbolTableCollection) {
699701
OpBuilder builder(op.getFunctionBody());
700702

@@ -717,21 +719,21 @@ spmdizeFuncOp(func::FuncOp op, IRMapping &spmdizationMap,
717719

718720
// Find a return op and change the function results signature to its operands
719721
// signature.
720-
func::ReturnOp returnOp;
721-
for (Block &block : op.getBody()) {
722+
Operation *returnOp = nullptr;
723+
for (Block &block : op.getFunctionBody()) {
722724
if (block.empty()) {
723725
continue;
724726
}
725727

726-
returnOp = llvm::cast<func::ReturnOp>(block.back());
727-
if (returnOp) {
728+
if (block.back().hasTrait<OpTrait::ReturnLike>()) {
729+
returnOp = &block.back();
728730
break;
729731
}
730732
}
731733
assert(returnOp);
732-
op.setFunctionType(FunctionType::get(op->getContext(),
733-
op.getBody().front().getArgumentTypes(),
734-
returnOp->getOperandTypes()));
734+
op.setType(FunctionType::get(op->getContext(),
735+
op.getFunctionBody().front().getArgumentTypes(),
736+
returnOp->getOperandTypes()));
735737

736738
return success();
737739
}

mlir/test/Dialect/Linalg/mesh-spmdization.mlir

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
// RUN: mlir-opt \
2-
// RUN: --mesh-spmdization \
3-
// RUN: --test-constant-fold \
2+
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
43
// RUN: --split-input-file \
54
// RUN: %s | FileCheck %s
65

mlir/test/Dialect/Mesh/sharding-propagation.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: mlir-opt -sharding-propagation %s | FileCheck %s
1+
// RUN: mlir-opt --pass-pipeline="builtin.module(func.func(sharding-propagation))" %s | FileCheck %s
22

33
mesh.mesh @mesh_1d(shape = ?)
44
mesh.mesh @mesh_2d(shape = 2x4)

mlir/test/Dialect/Mesh/spmdization.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
// RUN: mlir-opt --mesh-spmdization --test-constant-fold %s | FileCheck %s
1+
// RUN: mlir-opt \
2+
// RUN: --pass-pipeline="builtin.module(func.func(mesh-spmdization,test-constant-fold))" \
3+
// RUN: %s | FileCheck %s
24

35
mesh.mesh @mesh_1d(shape = 2)
46

0 commit comments

Comments
 (0)