Skip to content

Commit 4993f97

Browse files
GaoXiangYalanza
authored andcommitted
[CIR][ThroughMLIR] Lowering cir.do to scf.while,and fix cir.while lowering bugs. (#756)
In this pr, I lowering cir.do to scf.while, fix cir.while nested loop bugs and add test cases.
1 parent c7b79c2 commit 4993f97

File tree

3 files changed

+227
-10
lines changed

3 files changed

+227
-10
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 58 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,11 @@
1414
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1515
#include "mlir/Dialect/SCF/IR/SCF.h"
1616
#include "mlir/Dialect/SCF/Transforms/Passes.h"
17+
#include "mlir/IR/Builders.h"
1718
#include "mlir/IR/BuiltinDialect.h"
1819
#include "mlir/IR/BuiltinTypes.h"
20+
#include "mlir/IR/Location.h"
21+
#include "mlir/IR/ValueRange.h"
1922
#include "mlir/Pass/Pass.h"
2023
#include "mlir/Pass/PassManager.h"
2124
#include "mlir/Support/LogicalResult.h"
@@ -69,6 +72,19 @@ class SCFWhileLoop {
6972
mlir::ConversionPatternRewriter *rewriter;
7073
};
7174

75+
class SCFDoLoop {
76+
public:
77+
SCFDoLoop(mlir::cir::DoWhileOp op, mlir::cir::DoWhileOp::Adaptor adaptor,
78+
mlir::ConversionPatternRewriter *rewriter)
79+
: DoOp(op), adaptor(adaptor), rewriter(rewriter) {}
80+
void transferToSCFWhileOp();
81+
82+
private:
83+
mlir::cir::DoWhileOp DoOp;
84+
mlir::cir::DoWhileOp::Adaptor adaptor;
85+
mlir::ConversionPatternRewriter *rewriter;
86+
};
87+
7288
static int64_t getConstant(mlir::cir::ConstantOp op) {
7389
auto attr = op->getAttrs().front().getValue();
7490
const auto IntAttr = mlir::dyn_cast<mlir::cir::IntAttr>(attr);
@@ -240,13 +256,33 @@ void SCFWhileLoop::transferToSCFWhileOp() {
240256
whileOp->getLoc(), whileOp->getResultTypes(), adaptor.getOperands());
241257
rewriter->createBlock(&scfWhileOp.getBefore());
242258
rewriter->createBlock(&scfWhileOp.getAfter());
259+
rewriter->inlineBlockBefore(&whileOp.getCond().front(),
260+
scfWhileOp.getBeforeBody(),
261+
scfWhileOp.getBeforeBody()->end());
262+
rewriter->inlineBlockBefore(&whileOp.getBody().front(),
263+
scfWhileOp.getAfterBody(),
264+
scfWhileOp.getAfterBody()->end());
265+
}
243266

244-
rewriter->cloneRegionBefore(whileOp.getCond(),
245-
&scfWhileOp.getBefore().back());
246-
rewriter->eraseBlock(&scfWhileOp.getBefore().back());
247-
248-
rewriter->cloneRegionBefore(whileOp.getBody(), &scfWhileOp.getAfter().back());
249-
rewriter->eraseBlock(&scfWhileOp.getAfter().back());
267+
void SCFDoLoop::transferToSCFWhileOp() {
268+
269+
auto beforeBuilder = [&](mlir::OpBuilder &builder, mlir::Location loc,
270+
mlir::ValueRange args) {
271+
auto *newBlock = builder.getBlock();
272+
rewriter->mergeBlocks(&DoOp.getBody().front(), newBlock);
273+
auto *yieldOp = newBlock->getTerminator();
274+
rewriter->mergeBlocks(&DoOp.getCond().front(), newBlock,
275+
yieldOp->getResults());
276+
rewriter->eraseOp(yieldOp);
277+
};
278+
auto afterBuilder = [&](mlir::OpBuilder &builder, mlir::Location loc,
279+
mlir::ValueRange args) {
280+
rewriter->create<mlir::scf::YieldOp>(loc, args);
281+
};
282+
283+
rewriter->create<mlir::scf::WhileOp>(DoOp.getLoc(), DoOp->getResultTypes(),
284+
adaptor.getOperands(), beforeBuilder,
285+
afterBuilder);
250286
}
251287

252288
class CIRForOpLowering : public mlir::OpConversionPattern<mlir::cir::ForOp> {
@@ -279,6 +315,20 @@ class CIRWhileOpLowering
279315
}
280316
};
281317

318+
class CIRDoOpLowering : public mlir::OpConversionPattern<mlir::cir::DoWhileOp> {
319+
public:
320+
using OpConversionPattern<mlir::cir::DoWhileOp>::OpConversionPattern;
321+
322+
mlir::LogicalResult
323+
matchAndRewrite(mlir::cir::DoWhileOp op, OpAdaptor adaptor,
324+
mlir::ConversionPatternRewriter &rewriter) const override {
325+
SCFDoLoop loop(op, adaptor, &rewriter);
326+
loop.transferToSCFWhileOp();
327+
rewriter.eraseOp(op);
328+
return mlir::success();
329+
}
330+
};
331+
282332
class CIRConditionOpLowering
283333
: public mlir::OpConversionPattern<mlir::cir::ConditionOp> {
284334
public:
@@ -302,8 +352,8 @@ class CIRConditionOpLowering
302352

303353
void populateCIRLoopToSCFConversionPatterns(mlir::RewritePatternSet &patterns,
304354
mlir::TypeConverter &converter) {
305-
patterns.add<CIRForOpLowering, CIRWhileOpLowering, CIRConditionOpLowering>(
306-
converter, patterns.getContext());
355+
patterns.add<CIRForOpLowering, CIRWhileOpLowering, CIRConditionOpLowering,
356+
CIRDoOpLowering>(converter, patterns.getContext());
307357
}
308358

309359
} // namespace cir
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir %s -o %t.mlir
2+
// RUN: FileCheck --input-file=%t.mlir %s
3+
4+
int sum() {
5+
int s = 0;
6+
int i = 0;
7+
do {
8+
s += i;
9+
++i;
10+
} while (i <= 10);
11+
return s;
12+
}
13+
14+
void nestedDoWhile() {
15+
int a = 0;
16+
do {
17+
a++;
18+
int b = 0;
19+
while(b < 2) {
20+
b++;
21+
}
22+
}while(a < 2);
23+
}
24+
25+
// CHECK: func.func @sum() -> i32 {
26+
// CHECK: %[[ALLOC:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
27+
// CHECK: %[[ALLOC0:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
28+
// CHECK: %[[ALLOC1:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
29+
// CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
30+
// CHECK: memref.store %[[C0_I32]], %[[ALLOC0]][] : memref<i32>
31+
// CHECK: %[[C0_I32_2:.+]] = arith.constant 0 : i32
32+
// CHECK: memref.store %[[C0_I32_2]], %[[ALLOC1]][] : memref<i32>
33+
// CHECK: memref.alloca_scope {
34+
// CHECK: scf.while : () -> () {
35+
// CHECK: %[[VAR1:.+]] = memref.load %[[ALLOC1]][] : memref<i32>
36+
// CHECK: %[[VAR2:.+]] = memref.load %[[ALLOC0]][] : memref<i32>
37+
// CHECK: %[[ADD:.+]] = arith.addi %[[VAR2]], %[[VAR1]] : i32
38+
// CHECK: memref.store %[[ADD]], %[[ALLOC0]][] : memref<i32>
39+
// CHECK: %[[VAR3:.+]] = memref.load %[[ALLOC1]][] : memref<i32>
40+
// CHECK: %[[C1_I32:.+]] = arith.constant 1 : i32
41+
// CHECK: %[[ADD1:.+]] = arith.addi %[[VAR3]], %[[C1_I32]] : i32
42+
// CHECK: memref.store %[[ADD1]], %[[ALLOC1]][] : memref<i32>
43+
// CHECK: %[[VAR4:.+]] = memref.load %[[ALLOC1]][] : memref<i32>
44+
// CHECK: %[[C10_I32:.+]] = arith.constant 10 : i32
45+
// CHECK: %[[CMP:.+]] = arith.cmpi sle, %[[VAR4]], %[[C10_I32]] : i32
46+
// CHECK: %[[EXT:.+]] = arith.extui %[[CMP]] : i1 to i32
47+
// CHECK: %[[C0_I32_3:.+]] = arith.constant 0 : i32
48+
// CHECK: %[[NE:.+]] = arith.cmpi ne, %[[EXT]], %[[C0_I32_3]] : i32
49+
// CHECK: %[[EXT1:.+]] = arith.extui %[[NE]] : i1 to i8
50+
// CHECK: %[[TRUNC:.+]] = arith.trunci %[[EXT1]] : i8 to i1
51+
// CHECK: scf.condition(%[[TRUNC]])
52+
// CHECK: } do {
53+
// CHECK: scf.yield
54+
// CHECK: }
55+
// CHECK: }
56+
// CHECK: %[[LOAD:.+]] = memref.load %[[ALLOC0]][] : memref<i32>
57+
// CHECK: memref.store %[[LOAD]], %[[ALLOC]][] : memref<i32>
58+
// CHECK: %[[RET:.+]] = memref.load %[[ALLOC]][] : memref<i32>
59+
// CHECK: return %[[RET]] : i32
60+
61+
// CHECK: func.func @nestedDoWhile() {
62+
// CHECK: %[[alloca:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
63+
// CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
64+
// CHECK: memref.store %[[C0_I32]], %[[alloca]][] : memref<i32>
65+
// CHECK: memref.alloca_scope {
66+
// CHECK: %[[alloca_0:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
67+
// CHECK: scf.while : () -> () {
68+
// CHECK: %[[ZERO:.+]] = memref.load %[[alloca]][] : memref<i32>
69+
// CHECK: %[[C1_I32:.+]] = arith.constant 1 : i32
70+
// CHECK: %[[ONE:.+]] = arith.addi %[[ZERO]], %[[C1_I32]] : i32
71+
// CHECK: memref.store %[[ONE]], %[[alloca]][] : memref<i32>
72+
// CHECK: %[[C0_I32_1:.+]] = arith.constant 0 : i32
73+
// CHECK: memref.store %[[C0_I32_1]], %[[alloca_0]][] : memref<i32>
74+
// CHECK: memref.alloca_scope {
75+
// CHECK: scf.while : () -> () {
76+
// CHECK: %[[EIGHT:.+]] = memref.load %[[alloca_0]][] : memref<i32>
77+
// CHECK: %[[C2_I32_3:.+]] = arith.constant 2 : i32
78+
// CHECK: %[[NINE:.+]] = arith.cmpi slt, %[[EIGHT]], %[[C2_I32_3]] : i32
79+
// CHECK: %[[TEN:.+]] = arith.extui %9 : i1 to i32
80+
// CHECK: %[[C0_I32_4:.+]] = arith.constant 0 : i32
81+
// CHECK: %[[ELEVEN:.+]] = arith.cmpi ne, %[[TEN]], %[[C0_I32_4]] : i32
82+
// CHECK: %[[TWELVE:.+]] = arith.extui %[[ELEVEN]] : i1 to i8
83+
// CHECK: %[[THIRTEEN:.+]] = arith.trunci %[[TWELVE]] : i8 to i1
84+
// CHECK: scf.condition(%[[THIRTEEN]])
85+
// CHECK: } do {
86+
// CHECK: %[[EIGHT]] = memref.load %[[alloca_0]][] : memref<i32>
87+
// CHECK: %[[C1_I32_3:.+]] = arith.constant 1 : i32
88+
// CHECK: %[[NINE]] = arith.addi %[[EIGHT]], %[[C1_I32_3]] : i32
89+
// CHECK: memref.store %[[NINE]], %[[alloca_0]][] : memref<i32>
90+
// CHECK: scf.yield
91+
// CHECK: }
92+
// CHECK: }
93+
// CHECK: %[[TWO:.+]] = memref.load %[[alloca]][] : memref<i32>
94+
// CHECK: %[[C2_I32:.+]] = arith.constant 2 : i32
95+
// CHECK: %[[THREE:.+]] = arith.cmpi slt, %[[TWO]], %[[C2_I32]] : i32
96+
// CHECK: %[[FOUR:.+]] = arith.extui %[[THREE]] : i1 to i32
97+
// CHECK: %[[C0_I32_2:.+]] = arith.constant 0 : i32
98+
// CHECK: %[[FIVE:.+]] = arith.cmpi ne, %[[FOUR]], %[[C0_I32_2]] : i32
99+
// CHECK: %[[SIX:.+]] = arith.extui %[[FIVE]] : i1 to i8
100+
// CHECK: %[[SEVEN:.+]] = arith.trunci %[[SIX]] : i8 to i1
101+
// CHECK: scf.condition(%[[SEVEN]])
102+
// CHECK: } do {
103+
// CHECK: scf.yield
104+
// CHECK: }
105+
// CHECK: }
106+
// CHECK: return
107+
// CHECK: }

clang/test/CIR/Lowering/ThroughMLIR/while.c

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,25 @@
11
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir %s -o %t.mlir
22
// RUN: FileCheck --input-file=%t.mlir %s
33

4-
void foo() {
4+
void singleWhile() {
55
int a = 0;
66
while(a < 2) {
77
a++;
88
}
99
}
1010

11-
//CHECK: func.func @foo() {
11+
void nestedWhile() {
12+
int a = 0;
13+
while(a < 2) {
14+
int b = 0;
15+
while(b < 2) {
16+
b++;
17+
}
18+
a++;
19+
}
20+
}
21+
22+
//CHECK: func.func @singleWhile() {
1223
//CHECK: %[[alloca:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
1324
//CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
1425
//CHECK: memref.store %[[C0_I32]], %[[alloca]][] : memref<i32>
@@ -32,4 +43,53 @@ void foo() {
3243
//CHECK: }
3344
//CHECK: }
3445
//CHECK: return
46+
//CHECK: }
47+
48+
//CHECK: func.func @nestedWhile() {
49+
//CHECK: %[[alloca:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
50+
//CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
51+
//CHECK: memref.store %[[C0_I32]], %[[alloca]][] : memref<i32>
52+
//CHECK: memref.alloca_scope {
53+
//CHECK: %[[alloca_0:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
54+
//CHECK: scf.while : () -> () {
55+
//CHECK: %[[ZERO:.+]] = memref.load %alloca[] : memref<i32>
56+
//CHECK: %[[C2_I32:.+]] = arith.constant 2 : i32
57+
//CHECK: %[[ONE:.+]] = arith.cmpi slt, %[[ZERO]], %[[C2_I32]] : i32
58+
//CHECK: %[[TWO:.+]] = arith.extui %[[ONE]] : i1 to i32
59+
//CHECK: %[[C0_I32_1:.+]] = arith.constant 0 : i32
60+
//CHECK: %[[THREE:.+]] = arith.cmpi ne, %[[TWO]], %[[C0_I32_1]] : i32
61+
//CHECK: %[[FOUR:.+]] = arith.extui %[[THREE]] : i1 to i8
62+
//CHECK: %[[FIVE:.+]] = arith.trunci %[[FOUR]] : i8 to i1
63+
//CHECK: scf.condition(%[[FIVE]])
64+
//CHECK: } do {
65+
//CHECK: %[[C0_I32_1]] = arith.constant 0 : i32
66+
//CHECK: memref.store %[[C0_I32_1]], %[[alloca_0]][] : memref<i32>
67+
//CHECK: memref.alloca_scope {
68+
//CHECK: scf.while : () -> () {
69+
//CHECK: %[[TWO]] = memref.load %[[alloca_0]][] : memref<i32>
70+
//CHECK: %[[C2_I32]] = arith.constant 2 : i32
71+
//CHECK: %[[THREE]] = arith.cmpi slt, %[[TWO]], %[[C2_I32]] : i32
72+
//CHECK: %[[FOUR]] = arith.extui %[[THREE]] : i1 to i32
73+
//CHECK: %[[C0_I32_2:.+]] = arith.constant 0 : i32
74+
//CHECK: %[[FIVE]] = arith.cmpi ne, %[[FOUR]], %[[C0_I32_2]] : i32
75+
//CHECK: %[[SIX:.+]] = arith.extui %[[FIVE]] : i1 to i8
76+
//CHECK: %[[SEVEN:.+]] = arith.trunci %[[SIX]] : i8 to i1
77+
//CHECK: scf.condition(%[[SEVEN]])
78+
//CHECK: } do {
79+
//CHECK: %[[TWO]] = memref.load %[[alloca_0]][] : memref<i32>
80+
//CHECK: %[[C1_I32_2:.+]] = arith.constant 1 : i32
81+
//CHECK: %[[THREE]] = arith.addi %[[TWO]], %[[C1_I32_2]] : i32
82+
//CHECK: memref.store %[[THREE]], %[[alloca_0]][] : memref<i32>
83+
//CHECK: scf.yield
84+
//CHECK: }
85+
//CHECK: }
86+
//CHECK: %[[ZERO]] = memref.load %[[alloca]][] : memref<i32>
87+
//CHECK: %[[C1_I32:.+]] = arith.constant 1 : i32
88+
//CHECK: %[[ONE]] = arith.addi %[[ZERO]], %[[C1_I32]] : i32
89+
//CHECK: memref.store %[[ONE]], %[[alloca]][] : memref<i32>
90+
//CHECK: scf.yield
91+
//CHECK: }
92+
//CHECK: }
93+
//CHECK: return
94+
//CHECK: }
3595
//CHECK: }

0 commit comments

Comments
 (0)