Skip to content

Commit cea21e5

Browse files
Kureelanza
authored andcommitted
[CIR][Lowering] add cir.ternary to scf.if lowering (#368)
This PR adds `cir.ternary` lowering. There are two approaches to lower `cir.ternary` imo: 1. Use `scf.if` op. 2. Use `cf.cond_br` op. I choose `scf.if` because `scf.if` + canonicalization produces `arith.select` whereas `cf.cond_br` requires scf lifting. In many ways `scf.if` is more high-level and closer to `cir.ternary`. A separate `cir.yield` lowering is required since we cannot directly replace `cir.yield` in the ternary op lowering -- the yield operands may still be illegal and doing so produces `builtin.unrealized_cast` ops. I couldn't figured out a way to solve this issue without adding a separate lowering pattern. Please let me know if you know a way to solve this issue.
1 parent a673837 commit cea21e5

File tree

2 files changed

+98
-3
lines changed

2 files changed

+98
-3
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
#include "clang/CIR/Dialect/IR/CIRTypes.h"
4040
#include "clang/CIR/Passes.h"
4141
#include "llvm/ADT/Sequence.h"
42+
#include "llvm/ADT/TypeSwitch.h"
4243

4344
using namespace cir;
4445
using namespace llvm;
@@ -65,7 +66,8 @@ struct ConvertCIRToMLIRPass
6566
void getDependentDialects(mlir::DialectRegistry &registry) const override {
6667
registry.insert<mlir::BuiltinDialect, mlir::func::FuncDialect,
6768
mlir::affine::AffineDialect, mlir::memref::MemRefDialect,
68-
mlir::arith::ArithDialect, mlir::cf::ControlFlowDialect>();
69+
mlir::arith::ArithDialect, mlir::cf::ControlFlowDialect,
70+
mlir::scf::SCFDialect>();
6971
}
7072
void runOnOperation() final;
7173

@@ -547,15 +549,64 @@ struct CIRBrCondOpLowering
547549
}
548550
};
549551

552+
class CIRTernaryOpLowering
553+
: public mlir::OpConversionPattern<mlir::cir::TernaryOp> {
554+
public:
555+
using OpConversionPattern<mlir::cir::TernaryOp>::OpConversionPattern;
556+
557+
mlir::LogicalResult
558+
matchAndRewrite(mlir::cir::TernaryOp op, OpAdaptor adaptor,
559+
mlir::ConversionPatternRewriter &rewriter) const override {
560+
rewriter.setInsertionPoint(op);
561+
auto condition = adaptor.getCond();
562+
auto i1Condition = rewriter.create<mlir::arith::TruncIOp>(
563+
op.getLoc(), rewriter.getI1Type(), condition);
564+
SmallVector<mlir::Type> resultTypes;
565+
if (mlir::failed(getTypeConverter()->convertTypes(op->getResultTypes(),
566+
resultTypes)))
567+
return mlir::failure();
568+
569+
auto ifOp = rewriter.create<mlir::scf::IfOp>(op.getLoc(), resultTypes,
570+
i1Condition.getResult(), true);
571+
auto *thenBlock = &ifOp.getThenRegion().front();
572+
auto *elseBlock = &ifOp.getElseRegion().front();
573+
rewriter.inlineBlockBefore(&op.getTrueRegion().front(), thenBlock,
574+
thenBlock->end());
575+
rewriter.inlineBlockBefore(&op.getFalseRegion().front(), elseBlock,
576+
elseBlock->end());
577+
578+
rewriter.replaceOp(op, ifOp);
579+
return mlir::success();
580+
}
581+
};
582+
583+
class CIRYieldOpLowering
584+
: public mlir::OpConversionPattern<mlir::cir::YieldOp> {
585+
public:
586+
using OpConversionPattern<mlir::cir::YieldOp>::OpConversionPattern;
587+
mlir::LogicalResult
588+
matchAndRewrite(mlir::cir::YieldOp op, OpAdaptor adaptor,
589+
mlir::ConversionPatternRewriter &rewriter) const override {
590+
auto *parentOp = op->getParentOp();
591+
return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(parentOp)
592+
.Case<mlir::scf::IfOp>([&](auto) {
593+
rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(
594+
op, adaptor.getOperands());
595+
return mlir::success();
596+
})
597+
.Default([](auto) { return mlir::failure(); });
598+
}
599+
};
600+
550601
void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
551602
mlir::TypeConverter &converter) {
552603
patterns.add<CIRReturnLowering, CIRBrOpLowering>(patterns.getContext());
553604

554605
patterns.add<CIRCmpOpLowering, CIRCallLowering, CIRUnaryOpLowering,
555606
CIRBinOpLowering, CIRLoadLowering, CIRConstantLowering,
556607
CIRStoreLowering, CIRAllocaLowering, CIRFuncLowering,
557-
CIRScopeOpLowering, CIRBrCondOpLowering>(converter,
558-
patterns.getContext());
608+
CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering,
609+
CIRYieldOpLowering>(converter, patterns.getContext());
559610
}
560611

561612
static mlir::TypeConverter prepareTypeConverter() {
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
// RUN: cir-opt %s -cir-to-mlir | FileCheck %s -check-prefix=MLIR
2+
// RUN: cir-opt %s -cir-to-mlir --canonicalize | FileCheck %s --check-prefix=MLIR-CANONICALIZE
3+
// RUN: cir-opt %s -cir-to-mlir --canonicalize -cir-mlir-to-llvm | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM
4+
5+
!s32i = !cir.int<s, 32>
6+
7+
module {
8+
cir.func @_Z1xi(%arg0: !s32i) -> !s32i {
9+
%0 = cir.alloca !s32i, cir.ptr <!s32i>, ["y", init] {alignment = 4 : i64}
10+
%1 = cir.alloca !s32i, cir.ptr <!s32i>, ["__retval"] {alignment = 4 : i64}
11+
cir.store %arg0, %0 : !s32i, cir.ptr <!s32i>
12+
%2 = cir.load %0 : cir.ptr <!s32i>, !s32i
13+
%3 = cir.const(#cir.int<0> : !s32i) : !s32i
14+
%4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool
15+
%5 = cir.ternary(%4, true {
16+
%7 = cir.const(#cir.int<3> : !s32i) : !s32i
17+
cir.yield %7 : !s32i
18+
}, false {
19+
%7 = cir.const(#cir.int<5> : !s32i) : !s32i
20+
cir.yield %7 : !s32i
21+
}) : (!cir.bool) -> !s32i
22+
cir.store %5, %1 : !s32i, cir.ptr <!s32i>
23+
%6 = cir.load %1 : cir.ptr <!s32i>, !s32i
24+
cir.return %6 : !s32i
25+
}
26+
}
27+
28+
// MLIR: %1 = arith.cmpi ugt, %0, %c0_i32 : i32
29+
// MLIR-NEXT: %2 = arith.extui %1 : i1 to i8
30+
// MLIR-NEXT: %3 = arith.trunci %2 : i8 to i1
31+
// MLIR-NEXT: %4 = scf.if %3 -> (i32) {
32+
// MLIR-NEXT: %c3_i32 = arith.constant 3 : i32
33+
// MLIR-NEXT: scf.yield %c3_i32 : i32
34+
// MLIR-NEXT: } else {
35+
// MLIR-NEXT: %c5_i32 = arith.constant 5 : i32
36+
// MLIR-NEXT: scf.yield %c5_i32 : i32
37+
// MLIR-NEXT: }
38+
// MLIR-NEXT: memref.store %4, %alloca_0[] : memref<i32>
39+
40+
// MLIR-CANONICALIZE: %[[CMP:.*]] = arith.cmpi ugt
41+
// MLIR-CANONICALIZE: arith.select %[[CMP]]
42+
43+
// LLVM: %[[CMP:.*]] = icmp ugt
44+
// LLVM: select i1 %[[CMP]]

0 commit comments

Comments
 (0)