diff --git a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp index a9616ae62bc6..c38bbcde85f7 100644 --- a/clang/lib/CIR/Dialect/IR/CIRDialect.cpp +++ b/clang/lib/CIR/Dialect/IR/CIRDialect.cpp @@ -1395,11 +1395,20 @@ void TernaryOp::build(OpBuilder &builder, OperationState &result, Value cond, OpFoldResult SelectOp::fold(FoldAdaptor adaptor) { auto condition = adaptor.getCondition(); - if (!condition) - return nullptr; + if (condition) { + auto conditionValue = mlir::cast(condition).getValue(); + return conditionValue ? getTrueValue() : getFalseValue(); + } - auto conditionValue = mlir::cast(condition).getValue(); - return conditionValue ? getTrueValue() : getFalseValue(); + // cir.select if %0 then x else x -> x + auto trueValue = adaptor.getTrueValue(); + auto falseValue = adaptor.getFalseValue(); + if (trueValue && trueValue == falseValue) + return trueValue; + if (getTrueValue() == getFalseValue()) + return getTrueValue(); + + return nullptr; } //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp index 1eea92026134..9565b305b564 100644 --- a/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp +++ b/clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp @@ -107,6 +107,45 @@ struct RemoveTrivialTry : public OpRewritePattern { } }; +struct SimplifySelect : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SelectOp op, + PatternRewriter &rewriter) const final { + mlir::Operation *trueValueOp = op.getTrueValue().getDefiningOp(); + mlir::Operation *falseValueOp = op.getFalseValue().getDefiningOp(); + auto trueValueConstOp = + mlir::dyn_cast_if_present(trueValueOp); + auto falseValueConstOp = + mlir::dyn_cast_if_present(falseValueOp); + if (!trueValueConstOp || !falseValueConstOp) + return mlir::failure(); + + auto trueValue = + mlir::dyn_cast(trueValueConstOp.getValue()); + auto falseValue = + mlir::dyn_cast(falseValueConstOp.getValue()); + if (!trueValue || !falseValue) + return mlir::failure(); + + // cir.select if %0 then #true else #false -> %0 + if (trueValue.getValue() && !falseValue.getValue()) { + rewriter.replaceAllUsesWith(op, op.getCondition()); + rewriter.eraseOp(op); + return mlir::success(); + } + + // cir.seleft if %0 then #false else #true -> cir.unary not %0 + if (!trueValue.getValue() && falseValue.getValue()) { + rewriter.replaceOpWithNewOp( + op, mlir::cir::UnaryOpKind::Not, op.getCondition()); + return mlir::success(); + } + + return mlir::failure(); + } +}; + //===----------------------------------------------------------------------===// // CIRSimplifyPass //===----------------------------------------------------------------------===// @@ -131,7 +170,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) { RemoveRedundantBranches, RemoveEmptyScope, RemoveEmptySwitch, - RemoveTrivialTry + RemoveTrivialTry, + SimplifySelect >(patterns.getContext()); // clang-format on } diff --git a/clang/test/CIR/Transforms/select.cir b/clang/test/CIR/Transforms/select.cir index c3db14daaf4e..6d18be0b9439 100644 --- a/clang/test/CIR/Transforms/select.cir +++ b/clang/test/CIR/Transforms/select.cir @@ -1,4 +1,4 @@ -// RUN: cir-opt --canonicalize -o %t.cir %s +// RUN: cir-opt -cir-simplify -o %t.cir %s // RUN: FileCheck --input-file=%t.cir %s !s32i = !cir.int @@ -23,4 +23,38 @@ module { // CHECK: cir.func @fold_false(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i { // CHECK-NEXT: cir.return %[[ARG1]] : !s32i // CHECK-NEXT: } + + cir.func @fold_to_const(%arg0 : !cir.bool) -> !s32i { + %0 = cir.const #cir.int<42> : !s32i + %1 = cir.select if %arg0 then %0 else %0 : (!cir.bool, !s32i, !s32i) -> !s32i + cir.return %1 : !s32i + } + + // CHECK: cir.func @fold_to_const(%{{.+}}: !cir.bool) -> !s32i { + // CHECK-NEXT: %[[#A:]] = cir.const #cir.int<42> : !s32i + // CHECK-NEXT: cir.return %[[#A]] : !s32i + // CHECK-NEXT: } + + cir.func @simplify_1(%arg0 : !cir.bool) -> !cir.bool { + %0 = cir.const #cir.bool : !cir.bool + %1 = cir.const #cir.bool : !cir.bool + %2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool + cir.return %2 : !cir.bool + } + + // CHECK: cir.func @simplify_1(%[[ARG0:.+]]: !cir.bool) -> !cir.bool { + // CHECK-NEXT: cir.return %[[ARG0]] : !cir.bool + // CHECK-NEXT: } + + cir.func @simplify_2(%arg0 : !cir.bool) -> !cir.bool { + %0 = cir.const #cir.bool : !cir.bool + %1 = cir.const #cir.bool : !cir.bool + %2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool + cir.return %2 : !cir.bool + } + + // CHECK: cir.func @simplify_2(%[[ARG0:.+]]: !cir.bool) -> !cir.bool { + // CHECK-NEXT: %[[#A:]] = cir.unary(not, %[[ARG0]]) : !cir.bool, !cir.bool + // CHECK-NEXT: cir.return %[[#A]] : !cir.bool + // CHECK-NEXT: } }