diff --git a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp index 8471230c6eab..0853eeb87782 100644 --- a/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp +++ b/clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp @@ -39,6 +39,7 @@ #include "clang/CIR/Dialect/IR/CIRTypes.h" #include "clang/CIR/Passes.h" #include "llvm/ADT/Sequence.h" +#include "llvm/ADT/TypeSwitch.h" using namespace cir; using namespace llvm; @@ -65,7 +66,8 @@ struct ConvertCIRToMLIRPass void getDependentDialects(mlir::DialectRegistry ®istry) const override { registry.insert(); + mlir::arith::ArithDialect, mlir::cf::ControlFlowDialect, + mlir::scf::SCFDialect>(); } void runOnOperation() final; @@ -547,6 +549,55 @@ struct CIRBrCondOpLowering } }; +class CIRTernaryOpLowering + : public mlir::OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(mlir::cir::TernaryOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + rewriter.setInsertionPoint(op); + auto condition = adaptor.getCond(); + auto i1Condition = rewriter.create( + op.getLoc(), rewriter.getI1Type(), condition); + SmallVector resultTypes; + if (mlir::failed(getTypeConverter()->convertTypes(op->getResultTypes(), + resultTypes))) + return mlir::failure(); + + auto ifOp = rewriter.create(op.getLoc(), resultTypes, + i1Condition.getResult(), true); + auto *thenBlock = &ifOp.getThenRegion().front(); + auto *elseBlock = &ifOp.getElseRegion().front(); + rewriter.inlineBlockBefore(&op.getTrueRegion().front(), thenBlock, + thenBlock->end()); + rewriter.inlineBlockBefore(&op.getFalseRegion().front(), elseBlock, + elseBlock->end()); + + rewriter.replaceOp(op, ifOp); + return mlir::success(); + } +}; + +class CIRYieldOpLowering + : public mlir::OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + mlir::LogicalResult + matchAndRewrite(mlir::cir::YieldOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + auto *parentOp = op->getParentOp(); + return llvm::TypeSwitch(parentOp) + .Case([&](auto) { + rewriter.replaceOpWithNewOp( + op, adaptor.getOperands()); + return mlir::success(); + }) + .Default([](auto) { return mlir::failure(); }); + } +}; + void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns, mlir::TypeConverter &converter) { patterns.add(patterns.getContext()); @@ -554,8 +605,8 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns, patterns.add(converter, - patterns.getContext()); + CIRScopeOpLowering, CIRBrCondOpLowering, CIRTernaryOpLowering, + CIRYieldOpLowering>(converter, patterns.getContext()); } static mlir::TypeConverter prepareTypeConverter() { diff --git a/clang/test/CIR/Lowering/ThroughMLIR/tenary.cir b/clang/test/CIR/Lowering/ThroughMLIR/tenary.cir new file mode 100644 index 000000000000..df6e6a09a5ff --- /dev/null +++ b/clang/test/CIR/Lowering/ThroughMLIR/tenary.cir @@ -0,0 +1,44 @@ +// RUN: cir-opt %s -cir-to-mlir | FileCheck %s -check-prefix=MLIR +// RUN: cir-opt %s -cir-to-mlir --canonicalize | FileCheck %s --check-prefix=MLIR-CANONICALIZE +// RUN: cir-opt %s -cir-to-mlir --canonicalize -cir-mlir-to-llvm | mlir-translate -mlir-to-llvmir | FileCheck %s -check-prefix=LLVM + +!s32i = !cir.int + +module { +cir.func @_Z1xi(%arg0: !s32i) -> !s32i { + %0 = cir.alloca !s32i, cir.ptr , ["y", init] {alignment = 4 : i64} + %1 = cir.alloca !s32i, cir.ptr , ["__retval"] {alignment = 4 : i64} + cir.store %arg0, %0 : !s32i, cir.ptr + %2 = cir.load %0 : cir.ptr , !s32i + %3 = cir.const(#cir.int<0> : !s32i) : !s32i + %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool + %5 = cir.ternary(%4, true { + %7 = cir.const(#cir.int<3> : !s32i) : !s32i + cir.yield %7 : !s32i + }, false { + %7 = cir.const(#cir.int<5> : !s32i) : !s32i + cir.yield %7 : !s32i + }) : (!cir.bool) -> !s32i + cir.store %5, %1 : !s32i, cir.ptr + %6 = cir.load %1 : cir.ptr , !s32i + cir.return %6 : !s32i + } +} + +// MLIR: %1 = arith.cmpi ugt, %0, %c0_i32 : i32 +// MLIR-NEXT: %2 = arith.extui %1 : i1 to i8 +// MLIR-NEXT: %3 = arith.trunci %2 : i8 to i1 +// MLIR-NEXT: %4 = scf.if %3 -> (i32) { +// MLIR-NEXT: %c3_i32 = arith.constant 3 : i32 +// MLIR-NEXT: scf.yield %c3_i32 : i32 +// MLIR-NEXT: } else { +// MLIR-NEXT: %c5_i32 = arith.constant 5 : i32 +// MLIR-NEXT: scf.yield %c5_i32 : i32 +// MLIR-NEXT: } +// MLIR-NEXT: memref.store %4, %alloca_0[] : memref + +// MLIR-CANONICALIZE: %[[CMP:.*]] = arith.cmpi ugt +// MLIR-CANONICALIZE: arith.select %[[CMP]] + +// LLVM: %[[CMP:.*]] = icmp ugt +// LLVM: select i1 %[[CMP]]