Skip to content

Commit 38b90cb

Browse files
committed
[CIR][ThroughMLIR] Support lowering cir.do to scf.while
1 parent c289083 commit 38b90cb

File tree

2 files changed

+112
-3
lines changed

2 files changed

+112
-3
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,19 @@ class SCFWhileLoop {
6969
mlir::ConversionPatternRewriter *rewriter;
7070
};
7171

72+
class SCFDoLoop {
73+
public:
74+
SCFDoLoop(mlir::cir::DoWhileOp op, mlir::cir::DoWhileOp::Adaptor adaptor,
75+
mlir::ConversionPatternRewriter *rewriter)
76+
: DoOp(op), adaptor(adaptor), rewriter(rewriter) {}
77+
void transferToSCFWhileOp();
78+
79+
private:
80+
mlir::cir::DoWhileOp DoOp;
81+
mlir::cir::DoWhileOp::Adaptor adaptor;
82+
mlir::ConversionPatternRewriter *rewriter;
83+
};
84+
7285
static int64_t getConstant(mlir::cir::ConstantOp op) {
7386
auto attr = op->getAttrs().front().getValue();
7487
const auto IntAttr = attr.dyn_cast<mlir::cir::IntAttr>();
@@ -261,6 +274,40 @@ void SCFWhileLoop::transferToSCFWhileOp() {
261274
rewriter->eraseBlock(&scfWhileOp.getAfter().back());
262275
}
263276

277+
void SCFDoLoop::transferToSCFWhileOp() {
278+
// only support a simple do-while
279+
// FIXME: can not support nested do-while
280+
281+
auto scfWhileOp = rewriter->create<mlir::scf::WhileOp>(
282+
DoOp.getLoc(), DoOp->getResultTypes(), adaptor.getOperands());
283+
284+
rewriter->createBlock(&scfWhileOp.getBefore());
285+
rewriter->createBlock(&scfWhileOp.getAfter());
286+
287+
rewriter->cloneRegionBefore(DoOp.getBody(), &scfWhileOp.getBefore().back());
288+
rewriter->eraseBlock(&scfWhileOp.getBefore().back());
289+
290+
rewriter->cloneRegionBefore(DoOp.getCond(), &scfWhileOp.getAfter().back());
291+
rewriter->eraseBlock(&scfWhileOp.getAfter().back());
292+
293+
rewriter->inlineBlockBefore(&scfWhileOp.getAfter().back(),
294+
&scfWhileOp.getBefore().back(),
295+
scfWhileOp.getBefore().back().end());
296+
297+
rewriter->createBlock(&scfWhileOp.getAfter());
298+
299+
auto &beforeFrontBlock = scfWhileOp.getBefore().front();
300+
for (auto it = beforeFrontBlock.begin(); it != beforeFrontBlock.end(); ++it) {
301+
if (auto yieldOp = llvm::dyn_cast<mlir::cir::YieldOp>(&*it)) {
302+
rewriter->eraseOp(yieldOp);
303+
break;
304+
}
305+
}
306+
307+
rewriter->setInsertionPointToEnd(&scfWhileOp.getAfter().front());
308+
rewriter->create<mlir::scf::YieldOp>(DoOp.getLoc());
309+
}
310+
264311
class CIRForOpLowering : public mlir::OpConversionPattern<mlir::cir::ForOp> {
265312
public:
266313
using OpConversionPattern<mlir::cir::ForOp>::OpConversionPattern;
@@ -291,6 +338,20 @@ class CIRWhileOpLowering
291338
}
292339
};
293340

341+
class CIRDoOpLowering : public mlir::OpConversionPattern<mlir::cir::DoWhileOp> {
342+
public:
343+
using OpConversionPattern<mlir::cir::DoWhileOp>::OpConversionPattern;
344+
345+
mlir::LogicalResult
346+
matchAndRewrite(mlir::cir::DoWhileOp op, OpAdaptor adaptor,
347+
mlir::ConversionPatternRewriter &rewriter) const override {
348+
SCFDoLoop loop(op, adaptor, &rewriter);
349+
loop.transferToSCFWhileOp();
350+
rewriter.eraseOp(op);
351+
return mlir::success();
352+
}
353+
};
354+
294355
class CIRConditionOpLowering
295356
: public mlir::OpConversionPattern<mlir::cir::ConditionOp> {
296357
public:
@@ -314,8 +375,8 @@ class CIRConditionOpLowering
314375

315376
void populateCIRLoopToSCFConversionPatterns(mlir::RewritePatternSet &patterns,
316377
mlir::TypeConverter &converter) {
317-
patterns.add<CIRForOpLowering, CIRWhileOpLowering, CIRConditionOpLowering>(
318-
converter, patterns.getContext());
378+
patterns.add<CIRForOpLowering, CIRWhileOpLowering, CIRConditionOpLowering,
379+
CIRDoOpLowering>(converter, patterns.getContext());
319380
}
320381

321-
} // namespace cir
382+
} // namespace cir
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir %s -o %t.mlir
2+
// RUN: FileCheck --input-file=%t.mlir %s
3+
4+
int sum() {
5+
int s = 0;
6+
int i = 0;
7+
do {
8+
s += i;
9+
++i;
10+
} while (i <= 10);
11+
return s;
12+
}
13+
14+
// CHECK: func.func @sum() -> i32 {
15+
// CHECK: %[[ALLOC:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
16+
// CHECK: %[[ALLOC0:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
17+
// CHECK: %[[ALLOC1:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
18+
// CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
19+
// CHECK: memref.store %[[C0_I32]], %[[ALLOC0]][] : memref<i32>
20+
// CHECK: %[[C0_I32_2:.+]] = arith.constant 0 : i32
21+
// CHECK: memref.store %[[C0_I32_2]], %[[ALLOC1]][] : memref<i32>
22+
// CHECK: memref.alloca_scope {
23+
// CHECK: scf.while : () -> () {
24+
// CHECK: %[[VAR1:.+]] = memref.load %[[ALLOC1]][] : memref<i32>
25+
// CHECK: %[[VAR2:.+]] = memref.load %[[ALLOC0]][] : memref<i32>
26+
// CHECK: %[[ADD:.+]] = arith.addi %[[VAR2]], %[[VAR1]] : i32
27+
// CHECK: memref.store %[[ADD]], %[[ALLOC0]][] : memref<i32>
28+
// CHECK: %[[VAR3:.+]] = memref.load %[[ALLOC1]][] : memref<i32>
29+
// CHECK: %[[C1_I32:.+]] = arith.constant 1 : i32
30+
// CHECK: %[[ADD1:.+]] = arith.addi %[[VAR3]], %[[C1_I32]] : i32
31+
// CHECK: memref.store %[[ADD1]], %[[ALLOC1]][] : memref<i32>
32+
// CHECK: %[[VAR4:.+]] = memref.load %[[ALLOC1]][] : memref<i32>
33+
// CHECK: %[[C10_I32:.+]] = arith.constant 10 : i32
34+
// CHECK: %[[CMP:.+]] = arith.cmpi ule, %[[VAR4]], %[[C10_I32]] : i32
35+
// CHECK: %[[EXT:.+]] = arith.extui %[[CMP]] : i1 to i32
36+
// CHECK: %[[C0_I32_3:.+]] = arith.constant 0 : i32
37+
// CHECK: %[[NE:.+]] = arith.cmpi ne, %[[EXT]], %[[C0_I32_3]] : i32
38+
// CHECK: %[[EXT1:.+]] = arith.extui %[[NE]] : i1 to i8
39+
// CHECK: %[[TRUNC:.+]] = arith.trunci %[[EXT1]] : i8 to i1
40+
// CHECK: scf.condition(%[[TRUNC]])
41+
// CHECK: } do {
42+
// CHECK: scf.yield
43+
// CHECK: }
44+
// CHECK: }
45+
// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOC0]][] : memref<i32>
46+
// CHECK: memref.store %[[LOAD]], %[[ALLOC]][] : memref<i32>
47+
// CHECK: %[[RET:.+]] = memref.load %[[ALLOC]][] : memref<i32>
48+
// CHECK: return %[[RET]] : i32

0 commit comments

Comments
 (0)