Skip to content

Commit 4fcb3ac

Browse files
authored
[CIR][ThroughMLIR] Support lowering cir.if to scf.if (#640)
This pr introduces CIRIfOpLowering for lowering cir.if to scf.if
1 parent 7c87502 commit 4fcb3ac

File tree

2 files changed

+164
-2
lines changed

2 files changed

+164
-2
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -969,6 +969,31 @@ class CIRYieldOpLowering
969969
}
970970
};
971971

972+
class CIRIfOpLowering : public mlir::OpConversionPattern<mlir::cir::IfOp> {
973+
public:
974+
using mlir::OpConversionPattern<mlir::cir::IfOp>::OpConversionPattern;
975+
976+
mlir::LogicalResult
977+
matchAndRewrite(mlir::cir::IfOp ifop, OpAdaptor adaptor,
978+
mlir::ConversionPatternRewriter &rewriter) const override {
979+
auto condition = adaptor.getCondition();
980+
auto i1Condition = rewriter.create<mlir::arith::TruncIOp>(
981+
ifop->getLoc(), rewriter.getI1Type(), condition);
982+
auto newIfOp = rewriter.create<mlir::scf::IfOp>(
983+
ifop->getLoc(), ifop->getResultTypes(), i1Condition);
984+
auto *thenBlock = rewriter.createBlock(&newIfOp.getThenRegion());
985+
rewriter.inlineBlockBefore(&ifop.getThenRegion().front(), thenBlock,
986+
thenBlock->end());
987+
if (!ifop.getElseRegion().empty()) {
988+
auto *elseBlock = rewriter.createBlock(&newIfOp.getElseRegion());
989+
rewriter.inlineBlockBefore(&ifop.getElseRegion().front(), elseBlock,
990+
elseBlock->end());
991+
}
992+
rewriter.replaceOp(ifop, newIfOp);
993+
return mlir::success();
994+
}
995+
};
996+
972997
class CIRGlobalOpLowering
973998
: public mlir::OpConversionPattern<mlir::cir::GlobalOp> {
974999
public:
@@ -1272,8 +1297,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
12721297
CIRLogOpLowering, CIRRoundOpLowering, CIRPtrStrideOpLowering,
12731298
CIRSinOpLowering, CIRShiftOpLowering, CIRBitClzOpLowering,
12741299
CIRBitCtzOpLowering, CIRBitPopcountOpLowering, CIRBitClrsbOpLowering,
1275-
CIRBitFfsOpLowering, CIRBitParityOpLowering>(converter,
1276-
patterns.getContext());
1300+
CIRBitFfsOpLowering, CIRBitParityOpLowering, CIRIfOpLowering>(
1301+
converter, patterns.getContext());
12771302
}
12781303

12791304
static mlir::TypeConverter prepareTypeConverter() {
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir %s -o %t.mlir
2+
// RUN: FileCheck --input-file=%t.mlir %s
3+
4+
void foo() {
5+
int a = 2;
6+
int b = 0;
7+
if (a > 0) {
8+
b++;
9+
} else {
10+
b--;
11+
}
12+
}
13+
14+
//CHECK: func.func @foo() {
15+
//CHECK: %[[alloca:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
16+
//CHECK: %[[alloca_0:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
17+
//CHECK: %[[C2_I32:.+]] = arith.constant 2 : i32
18+
//CHECK: memref.store %[[C2_I32]], %[[alloca]][] : memref<i32>
19+
//CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
20+
//CHECK: memref.store %[[C0_I32]], %[[alloca_0]][] : memref<i32>
21+
//CHECK: memref.alloca_scope {
22+
//CHECK: %[[ZERO:.+]] = memref.load %[[alloca]][] : memref<i32>
23+
//CHECK: %[[C0_I32_1:.+]] = arith.constant 0 : i32
24+
//CHECK: %[[ONE:.+]] = arith.cmpi ugt, %[[ZERO]], %[[C0_I32_1]] : i32
25+
//CHECK: %[[TWO:.+]] = arith.extui %[[ONE]] : i1 to i32
26+
//CHECK: %[[C0_I32_2:.+]] = arith.constant 0 : i32
27+
//CHECK: %[[THREE:.+]] = arith.cmpi ne, %[[TWO]], %[[C0_I32_2]] : i32
28+
//CHECK: %[[FOUR:.+]] = arith.extui %[[THREE]] : i1 to i8
29+
//CHECK: %[[FIVE:.+]] = arith.trunci %[[FOUR]] : i8 to i1
30+
//CHECK: scf.if %[[FIVE]] {
31+
//CHECK: %[[SIX:.+]] = memref.load %[[alloca_0]][] : memref<i32>
32+
//CHECK: %[[C1_I32:.+]] = arith.constant 1 : i32
33+
//CHECK: %[[SEVEN:.+]] = arith.addi %[[SIX]], %[[C1_I32]] : i32
34+
//CHECK: memref.store %[[SEVEN]], %[[alloca_0]][] : memref<i32>
35+
//CHECK: } else {
36+
//CHECK: %[[SIX:.+]] = memref.load %[[alloca_0]][] : memref<i32>
37+
//CHECK: %[[C1_I32:.+]] = arith.constant 1 : i32
38+
//CHECK: %[[SEVEN:.+]] = arith.subi %[[SIX]], %[[C1_I32]] : i32
39+
//CHECK: memref.store %[[SEVEN]], %[[alloca_0]][] : memref<i32>
40+
//CHECK: }
41+
//CHECK: }
42+
//CHECK: return
43+
//CHECK: }
44+
45+
void foo2() {
46+
int a = 2;
47+
int b = 0;
48+
if (a < 3) {
49+
b++;
50+
}
51+
}
52+
53+
//CHECK: func.func @foo2() {
54+
//CHECK: %[[alloca:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
55+
//CHECK: %[[alloca_0:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
56+
//CHECK: %[[C2_I32:.+]] = arith.constant 2 : i32
57+
//CHECK: memref.store %[[C2_I32]], %[[alloca]][] : memref<i32>
58+
//CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
59+
//CHECK: memref.store %[[C0_I32]], %[[alloca_0]][] : memref<i32>
60+
//CHECK: memref.alloca_scope {
61+
//CHECK: %[[ZERO:.+]] = memref.load %[[alloca]][] : memref<i32>
62+
//CHECK: %[[C3_I32:.+]] = arith.constant 3 : i32
63+
//CHECK: %[[ONE:.+]] = arith.cmpi ult, %[[ZERO]], %[[C3_I32]] : i32
64+
//CHECK: %[[TWO:.+]] = arith.extui %[[ONE]] : i1 to i32
65+
//CHECK: %[[C0_I32_1]] = arith.constant 0 : i32
66+
//CHECK: %[[THREE:.+]] = arith.cmpi ne, %[[TWO]], %[[C0_I32_1]] : i32
67+
//CHECK: %[[FOUR:.+]] = arith.extui %[[THREE]] : i1 to i8
68+
//CHECK: %[[FIVE]] = arith.trunci %[[FOUR]] : i8 to i1
69+
//CHECK: scf.if %[[FIVE]] {
70+
//CHECK: %[[SIX:.+]] = memref.load %[[alloca_0]][] : memref<i32>
71+
//CHECK: %[[C1_I32:.+]] = arith.constant 1 : i32
72+
//CHECK: %[[SEVEN:.+]] = arith.addi %[[SIX]], %[[C1_I32]] : i32
73+
//CHECK: memref.store %[[SEVEN]], %[[alloca_0]][] : memref<i32>
74+
//CHECK: }
75+
//CHECK: }
76+
//CHECK: return
77+
//CHECK: }
78+
79+
void foo3() {
80+
int a = 2;
81+
int b = 0;
82+
if (a < 3) {
83+
int c = 1;
84+
if (c > 2) {
85+
b++;
86+
} else {
87+
b--;
88+
}
89+
}
90+
}
91+
92+
93+
//CHECK: func.func @foo3() {
94+
//CHECK: %[[alloca:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
95+
//CHECK: %[[alloca_0:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
96+
//CHECK: %[[C2_I32:.+]] = arith.constant 2 : i32
97+
//CHECK: memref.store %[[C2_I32]], %[[alloca]][] : memref<i32>
98+
//CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
99+
//CHECK: memref.store %[[C0_I32]], %[[alloca_0]][] : memref<i32>
100+
//CHECK: memref.alloca_scope {
101+
//CHECK: %[[ZERO:.+]] = memref.load %[[alloca]][] : memref<i32>
102+
//CHECK: %[[C3_I32:.+]] = arith.constant 3 : i32
103+
//CHECK: %[[ONE:.+]] = arith.cmpi ult, %[[ZERO]], %[[C3_I32]] : i32
104+
//CHECK: %[[TWO:.+]] = arith.extui %[[ONE]] : i1 to i32
105+
//CHECK: %[[C0_I32_1:.+]] = arith.constant 0 : i32
106+
//CHECK: %[[THREE:.+]] = arith.cmpi ne, %[[TWO:.+]], %[[C0_I32_1]] : i32
107+
//CHECK: %[[FOUR:.+]] = arith.extui %[[THREE]] : i1 to i8
108+
//CHECK: %[[FIVE]] = arith.trunci %[[FOUR]] : i8 to i1
109+
//CHECK: scf.if %[[FIVE]] {
110+
//CHECK: %[[alloca_2:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
111+
//CHECK: %[[C1_I32:.+]] = arith.constant 1 : i32
112+
//CHECK: memref.store %[[C1_I32]], %[[alloca_2]][] : memref<i32>
113+
//CHECK: memref.alloca_scope {
114+
//CHECK: %[[SIX:.+]] = memref.load %[[alloca_2]][] : memref<i32>
115+
//CHECK: %[[C2_I32_3:.+]] = arith.constant 2 : i32
116+
//CHECK: %[[SEVEN:.+]] = arith.cmpi ugt, %[[SIX]], %[[C2_I32_3]] : i32
117+
//CHECK: %[[EIGHT:.+]] = arith.extui %[[SEVEN]] : i1 to i32
118+
//CHECK: %[[C0_I32_4:.+]] = arith.constant 0 : i32
119+
//CHECK: %[[NINE:.+]] = arith.cmpi ne, %[[EIGHT]], %[[C0_I32_4]] : i32
120+
//CHECK: %[[TEN:.+]] = arith.extui %[[NINE]] : i1 to i8
121+
//CHECK: %[[ELEVEN:.+]] = arith.trunci %[[TEN]] : i8 to i1
122+
//CHECK: scf.if %[[ELEVEN]] {
123+
//CHECK: %[[TWELVE:.+]] = memref.load %[[alloca_0]][] : memref<i32>
124+
//CHECK: %[[C1_I32_5:.+]] = arith.constant 1 : i32
125+
//CHECK: %[[THIRTEEN:.+]] = arith.addi %[[TWELVE]], %[[C1_I32_5]] : i32
126+
//CHECK: memref.store %[[THIRTEEN]], %[[alloca_0]][] : memref<i32>
127+
//CHECK: } else {
128+
//CHECK: %[[TWELVE:.+]] = memref.load %[[alloca_0]][] : memref<i32>
129+
//CHECK: %[[C1_I32_5:.+]] = arith.constant 1 : i32
130+
//CHECK: %[[THIRTEEN:.+]] = arith.subi %[[TWELVE]], %[[C1_I32_5]] : i32
131+
//CHECK: memref.store %[[THIRTEEN]], %[[alloca_0]][] : memref<i32>
132+
//CHECK: }
133+
//CHECK: }
134+
//CHECK: }
135+
//CHECK: }
136+
//CHECK: return
137+
//CHECK: }

0 commit comments

Comments
 (0)