Skip to content

Commit 7c87502

Browse files
authored
[CIR][ThroughMLIR] Support lowering cir.condition and cir.while to scf.condition, scf.while (#636)
This pr intruduces CIRConditionLowering and CIRWhileLowering for lowering to scf.
1 parent 6cc4973 commit 7c87502

File tree

3 files changed

+109
-5
lines changed

3 files changed

+109
-5
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "clang/CIR/Dialect/IR/CIRTypes.h"
2525
#include "clang/CIR/LowerToMLIR.h"
2626
#include "clang/CIR/Passes.h"
27+
#include "llvm/ADT/TypeSwitch.h"
2728

2829
using namespace cir;
2930
using namespace llvm;
@@ -55,6 +56,19 @@ class SCFLoop {
5556
int64_t step = 0;
5657
};
5758

59+
class SCFWhileLoop {
60+
public:
61+
SCFWhileLoop(mlir::cir::WhileOp op, mlir::cir::WhileOp::Adaptor adaptor,
62+
mlir::ConversionPatternRewriter *rewriter)
63+
: whileOp(op), adaptor(adaptor), rewriter(rewriter) {}
64+
void transferToSCFWhileOp();
65+
66+
private:
67+
mlir::cir::WhileOp whileOp;
68+
mlir::cir::WhileOp::Adaptor adaptor;
69+
mlir::ConversionPatternRewriter *rewriter;
70+
};
71+
5872
static int64_t getConstant(mlir::cir::ConstantOp op) {
5973
auto attr = op->getAttrs().front().getValue();
6074
const auto IntAttr = attr.dyn_cast<mlir::cir::IntAttr>();
@@ -233,6 +247,20 @@ void SCFLoop::transferToSCFForOp() {
233247
});
234248
}
235249

250+
void SCFWhileLoop::transferToSCFWhileOp() {
251+
auto scfWhileOp = rewriter->create<mlir::scf::WhileOp>(
252+
whileOp->getLoc(), whileOp->getResultTypes(), adaptor.getOperands());
253+
rewriter->createBlock(&scfWhileOp.getBefore());
254+
rewriter->createBlock(&scfWhileOp.getAfter());
255+
256+
rewriter->cloneRegionBefore(whileOp.getCond(),
257+
&scfWhileOp.getBefore().back());
258+
rewriter->eraseBlock(&scfWhileOp.getBefore().back());
259+
260+
rewriter->cloneRegionBefore(whileOp.getBody(), &scfWhileOp.getAfter().back());
261+
rewriter->eraseBlock(&scfWhileOp.getAfter().back());
262+
}
263+
236264
class CIRForOpLowering : public mlir::OpConversionPattern<mlir::cir::ForOp> {
237265
public:
238266
using OpConversionPattern<mlir::cir::ForOp>::OpConversionPattern;
@@ -248,9 +276,46 @@ class CIRForOpLowering : public mlir::OpConversionPattern<mlir::cir::ForOp> {
248276
}
249277
};
250278

279+
class CIRWhileOpLowering
280+
: public mlir::OpConversionPattern<mlir::cir::WhileOp> {
281+
public:
282+
using OpConversionPattern<mlir::cir::WhileOp>::OpConversionPattern;
283+
284+
mlir::LogicalResult
285+
matchAndRewrite(mlir::cir::WhileOp op, OpAdaptor adaptor,
286+
mlir::ConversionPatternRewriter &rewriter) const override {
287+
SCFWhileLoop loop(op, adaptor, &rewriter);
288+
loop.transferToSCFWhileOp();
289+
rewriter.eraseOp(op);
290+
return mlir::success();
291+
}
292+
};
293+
294+
class CIRConditionOpLowering
295+
: public mlir::OpConversionPattern<mlir::cir::ConditionOp> {
296+
public:
297+
using OpConversionPattern<mlir::cir::ConditionOp>::OpConversionPattern;
298+
mlir::LogicalResult
299+
matchAndRewrite(mlir::cir::ConditionOp op, OpAdaptor adaptor,
300+
mlir::ConversionPatternRewriter &rewriter) const override {
301+
auto *parentOp = op->getParentOp();
302+
return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(parentOp)
303+
.Case<mlir::scf::WhileOp>([&](auto) {
304+
auto condition = adaptor.getCondition();
305+
auto i1Condition = rewriter.create<mlir::arith::TruncIOp>(
306+
op.getLoc(), rewriter.getI1Type(), condition);
307+
rewriter.replaceOpWithNewOp<mlir::scf::ConditionOp>(
308+
op, i1Condition, parentOp->getOperands());
309+
return mlir::success();
310+
})
311+
.Default([](auto) { return mlir::failure(); });
312+
}
313+
};
314+
251315
void populateCIRLoopToSCFConversionPatterns(mlir::RewritePatternSet &patterns,
252316
mlir::TypeConverter &converter) {
253-
patterns.add<CIRForOpLowering>(converter, patterns.getContext());
317+
patterns.add<CIRForOpLowering, CIRWhileOpLowering, CIRConditionOpLowering>(
318+
converter, patterns.getContext());
254319
}
255320

256-
} // namespace cir
321+
} // namespace cir

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@
3131
#include "mlir/Dialect/SCF/Transforms/Passes.h"
3232
#include "mlir/IR/BuiltinDialect.h"
3333
#include "mlir/IR/BuiltinTypes.h"
34+
#include "mlir/IR/Operation.h"
35+
#include "mlir/IR/Region.h"
36+
#include "mlir/IR/TypeRange.h"
37+
#include "mlir/IR/ValueRange.h"
3438
#include "mlir/Pass/Pass.h"
3539
#include "mlir/Pass/PassManager.h"
3640
#include "mlir/Support/LogicalResult.h"
@@ -43,7 +47,9 @@
4347
#include "clang/CIR/Dialect/IR/CIRTypes.h"
4448
#include "clang/CIR/LowerToMLIR.h"
4549
#include "clang/CIR/Passes.h"
50+
#include "llvm/ADT/STLExtras.h"
4651
#include "llvm/ADT/Sequence.h"
52+
#include "llvm/ADT/SmallVector.h"
4753
#include "llvm/ADT/TypeSwitch.h"
4854

4955
using namespace cir;
@@ -558,7 +564,6 @@ class CIRFuncOpLowering : public mlir::OpConversionPattern<mlir::cir::FuncOp> {
558564
return mlir::failure();
559565

560566
rewriter.eraseOp(op);
561-
562567
return mlir::LogicalResult::success();
563568
}
564569
};
@@ -883,7 +888,6 @@ class CIRScopeOpLowering
883888
if (mlir::failed(getTypeConverter()->convertTypes(scopeOp->getResultTypes(),
884889
mlirResultTypes)))
885890
return mlir::LogicalResult::failure();
886-
887891
rewriter.setInsertionPoint(scopeOp);
888892
auto newScopeOp = rewriter.create<mlir::memref::AllocaScopeOp>(
889893
scopeOp.getLoc(), mlirResultTypes);
@@ -956,7 +960,7 @@ class CIRYieldOpLowering
956960
mlir::ConversionPatternRewriter &rewriter) const override {
957961
auto *parentOp = op->getParentOp();
958962
return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(parentOp)
959-
.Case<mlir::scf::IfOp, mlir::scf::ForOp>([&](auto) {
963+
.Case<mlir::scf::IfOp, mlir::scf::ForOp, mlir::scf::WhileOp>([&](auto) {
960964
rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(
961965
op, adaptor.getOperands());
962966
return mlir::success();
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -fno-clangir-direct-lowering -emit-mlir %s -o %t.mlir
2+
// RUN: FileCheck --input-file=%t.mlir %s
3+
4+
void foo() {
5+
int a = 0;
6+
while(a < 2) {
7+
a++;
8+
}
9+
}
10+
11+
//CHECK: func.func @foo() {
12+
//CHECK: %[[alloca:.+]] = memref.alloca() {alignment = 4 : i64} : memref<i32>
13+
//CHECK: %[[C0_I32:.+]] = arith.constant 0 : i32
14+
//CHECK: memref.store %[[C0_I32]], %[[alloca]][] : memref<i32>
15+
//CHECK: memref.alloca_scope {
16+
//CHECK: scf.while : () -> () {
17+
//CHECK: %[[ZERO:.+]] = memref.load %[[alloca]][] : memref<i32>
18+
//CHECK: %[[C2_I32:.+]] = arith.constant 2 : i32
19+
//CHECK: %[[ONE:.+]] = arith.cmpi ult, %[[ZERO:.+]], %[[C2_I32]] : i32
20+
//CHECK: %[[TWO:.+]] = arith.extui %[[ONE:.+]] : i1 to i32
21+
//CHECK: %[[C0_I32_0:.+]] = arith.constant 0 : i32
22+
//CHECK: %[[THREE:.+]] = arith.cmpi ne, %[[TWO:.+]], %[[C0_I32_0]] : i32
23+
//CHECK: %[[FOUR:.+]] = arith.extui %[[THREE:.+]] : i1 to i8
24+
//CHECK: %[[FIVE:.+]] = arith.trunci %[[FOUR:.+]] : i8 to i1
25+
//CHECK: scf.condition(%[[FIVE]])
26+
//CHECK: } do {
27+
//CHECK: %[[ZERO:.+]] = memref.load %[[alloca]][] : memref<i32>
28+
//CHECK: %[[C1_I32:.+]] = arith.constant 1 : i32
29+
//CHECK: %[[ONE:.+]] = arith.addi %0, %[[C1_I32:.+]] : i32
30+
//CHECK: memref.store %[[ONE:.+]], %[[alloca]][] : memref<i32>
31+
//CHECK: scf.yield
32+
//CHECK: }
33+
//CHECK: }
34+
//CHECK: return
35+
//CHECK: }

0 commit comments

Comments
 (0)