Skip to content

Commit dc06ede

Browse files
committed
[CIR][CodeGen] Basic lowering of increment/decrement
1 parent 2cdbe0d commit dc06ede

File tree

2 files changed

+62
-2
lines changed

2 files changed

+62
-2
lines changed

clang/lib/CIR/CodeGen/LowerToLLVM.cpp

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,37 @@ class CIRFuncLowering : public mlir::OpRewritePattern<mlir::cir::FuncOp> {
206206
}
207207
};
208208

209+
class CIRUnaryOpLowering : public mlir::OpRewritePattern<mlir::cir::UnaryOp> {
210+
public:
211+
using OpRewritePattern<mlir::cir::UnaryOp>::OpRewritePattern;
212+
213+
mlir::LogicalResult
214+
matchAndRewrite(mlir::cir::UnaryOp op,
215+
mlir::PatternRewriter &rewriter) const override {
216+
mlir::Type type = op.getInput().getType();
217+
assert(type.isa<mlir::IntegerType>() && "operand type not supported yet");
218+
219+
switch (op.getKind()) {
220+
case mlir::cir::UnaryOpKind::Inc: {
221+
auto One = rewriter.create<mlir::arith::ConstantOp>(
222+
op.getLoc(), type, mlir::IntegerAttr::get(type, 1));
223+
rewriter.replaceOpWithNewOp<mlir::arith::AddIOp>(op, op.getType(),
224+
op.getInput(), One);
225+
break;
226+
}
227+
case mlir::cir::UnaryOpKind::Dec: {
228+
auto One = rewriter.create<mlir::arith::ConstantOp>(
229+
op.getLoc(), type, mlir::IntegerAttr::get(type, 1));
230+
rewriter.replaceOpWithNewOp<mlir::arith::SubIOp>(op, op.getType(),
231+
op.getInput(), One);
232+
break;
233+
}
234+
}
235+
236+
return mlir::LogicalResult::success();
237+
}
238+
};
239+
209240
class CIRBinOpLowering : public mlir::OpRewritePattern<mlir::cir::BinOp> {
210241
public:
211242
using OpRewritePattern<mlir::cir::BinOp>::OpRewritePattern;
@@ -447,8 +478,8 @@ class CIRBrOpLowering : public mlir::OpRewritePattern<mlir::cir::BrOp> {
447478

448479
void populateCIRToMemRefConversionPatterns(mlir::RewritePatternSet &patterns) {
449480
patterns.add<CIRAllocaLowering, CIRLoadLowering, CIRStoreLowering,
450-
CIRConstantLowering, CIRBinOpLowering, CIRCmpOpLowering,
451-
CIRBrOpLowering>(patterns.getContext());
481+
CIRConstantLowering, CIRUnaryOpLowering, CIRBinOpLowering,
482+
CIRCmpOpLowering, CIRBrOpLowering>(patterns.getContext());
452483
}
453484

454485
void ConvertCIRToLLVMPass::runOnOperation() {
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
// RUN: cir-tool %s -cir-to-func -cir-to-memref -o - | FileCheck %s -check-prefix=MLIR
2+
// RUN: cir-tool %s -cir-to-func -cir-to-memref -cir-to-llvm -o - | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM
3+
4+
module {
5+
cir.func @foo() {
6+
%0 = cir.alloca i32, cir.ptr <i32>, ["a", cinit] {alignment = 4 : i64}
7+
%1 = cir.alloca i32, cir.ptr <i32>, ["b", cinit] {alignment = 4 : i64}
8+
%2 = cir.cst(2 : i32) : i32
9+
cir.store %2, %0 : i32, cir.ptr <i32>
10+
cir.store %2, %1 : i32, cir.ptr <i32>
11+
12+
%3 = cir.load %0 : cir.ptr <i32>, i32
13+
%4 = cir.unary(inc, %3) : i32, i32
14+
cir.store %4, %0 : i32, cir.ptr <i32>
15+
16+
%5 = cir.load %1 : cir.ptr <i32>, i32
17+
%6 = cir.unary(dec, %5) : i32, i32
18+
cir.store %6, %1 : i32, cir.ptr <i32>
19+
cir.return
20+
}
21+
}
22+
23+
// MLIR: = arith.constant 1
24+
// MLIR: = arith.addi
25+
// MLIR: = arith.constant 1
26+
// MLIR: = arith.subi
27+
28+
// LLVM: = add i32 %[[#]], 1
29+
// LLVM: = sub i32 %[[#]], 1

0 commit comments

Comments
 (0)