Skip to content

[CIR][Transform] Add ternary simplification #809

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 87 additions & 2 deletions clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,15 @@

#include "PassDetail.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Block.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/Region.h"
#include "mlir/Support/LogicalResult.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "clang/CIR/Dialect/IR/CIRDialect.h"
#include "clang/CIR/Dialect/Passes.h"
#include "llvm/ADT/SmallVector.h"

using namespace mlir;
using namespace cir;
Expand Down Expand Up @@ -107,6 +111,85 @@ struct RemoveTrivialTry : public OpRewritePattern<TryOp> {
}
};

/// Simplify suitable ternary operations into select operations.
///
/// For now we only simplify those ternary operations whose true and false
/// branches directly yield a value or a constant. That is, both of the true and
/// the false branch must either contain a cir.yield operation as the only
/// operation in the branch, or contain a cir.const operation followed by a
/// cir.yield operation that yields the constant value.
///
/// For example, we will simplify the following ternary operation:
///
/// %0 = cir.ternary (%condition, true {
/// %1 = cir.const ...
/// cir.yield %1
/// } false {
/// cir.yield %2
/// })
///
/// into the following sequence of operations:
///
/// %1 = cir.const ...
/// %0 = cir.select if %condition then %1 else %2
struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
using OpRewritePattern<TernaryOp>::OpRewritePattern;

LogicalResult matchAndRewrite(TernaryOp op,
PatternRewriter &rewriter) const override {
if (op->getNumResults() != 1)
return mlir::failure();

if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
!isSimpleTernaryBranch(op.getFalseRegion()))
return mlir::failure();

mlir::cir::YieldOp trueBranchYieldOp = mlir::cast<mlir::cir::YieldOp>(
op.getTrueRegion().front().getTerminator());
mlir::cir::YieldOp falseBranchYieldOp = mlir::cast<mlir::cir::YieldOp>(
op.getFalseRegion().front().getTerminator());
auto trueValue = trueBranchYieldOp.getArgs()[0];
auto falseValue = falseBranchYieldOp.getArgs()[0];

rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
rewriter.eraseOp(trueBranchYieldOp);
rewriter.eraseOp(falseBranchYieldOp);
rewriter.replaceOpWithNewOp<mlir::cir::SelectOp>(op, op.getCond(),
trueValue, falseValue);

return mlir::success();
}

private:
bool isSimpleTernaryBranch(mlir::Region &region) const {
if (!region.hasOneBlock())
return false;

mlir::Block &onlyBlock = region.front();
auto &ops = onlyBlock.getOperations();

// The region/block could only contain at most 2 operations.
if (ops.size() > 2)
return false;

if (ops.size() == 1) {
// The region/block only contain a cir.yield operation.
return true;
}

// Check whether the region/block contains a cir.const followed by a
// cir.yield that yields the value.
auto yieldOp = mlir::cast<mlir::cir::YieldOp>(onlyBlock.getTerminator());
auto yieldValueDefOp = mlir::dyn_cast_if_present<mlir::cir::ConstantOp>(
yieldOp.getArgs()[0].getDefiningOp());
if (!yieldValueDefOp || yieldValueDefOp->getBlock() != &onlyBlock)
return false;

return true;
}
};

struct SimplifySelect : public OpRewritePattern<SelectOp> {
using OpRewritePattern<SelectOp>::OpRewritePattern;

Expand Down Expand Up @@ -171,6 +254,7 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
RemoveEmptyScope,
RemoveEmptySwitch,
RemoveTrivialTry,
SimplifyTernary,
SimplifySelect
>(patterns.getContext());
// clang-format on
Expand All @@ -186,8 +270,9 @@ void CIRSimplifyPass::runOnOperation() {
getOperation()->walk([&](Operation *op) {
// CastOp here is to perform a manual `fold` in
// applyOpPatternsAndFold
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp, SelectOp,
ComplexCreateOp, ComplexRealOp, ComplexImagOp>(op))
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp,
TernaryOp, SelectOp, ComplexCreateOp, ComplexRealOp, ComplexImagOp>(
op))
ops.push_back(op);
});

Expand Down
27 changes: 5 additions & 22 deletions clang/test/CIR/CodeGen/binop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,7 @@ void b1(bool a, bool b) {

// CHECK: cir.ternary(%3, true
// CHECK-NEXT: %7 = cir.load %1
// CHECK-NEXT: cir.ternary(%7, true
// CHECK-NEXT: cir.const #true
// CHECK-NEXT: cir.yield
// CHECK-NEXT: false {
// CHECK-NEXT: cir.const #false
// CHECK-NEXT: cir.yield
// CHECK: cir.yield
// CHECK-NEXT: cir.yield %7
// CHECK-NEXT: false {
// CHECK-NEXT: cir.const #false
// CHECK-NEXT: cir.yield
Expand All @@ -48,11 +42,6 @@ void b1(bool a, bool b) {
// CHECK-NEXT: cir.yield
// CHECK-NEXT: false {
// CHECK-NEXT: %7 = cir.load %1
// CHECK-NEXT: cir.ternary(%7, true
// CHECK-NEXT: cir.const #true
// CHECK-NEXT: cir.yield
// CHECK-NEXT: false {
// CHECK-NEXT: cir.const #false
// CHECK-NEXT: cir.yield

void b2(bool a) {
Expand Down Expand Up @@ -90,16 +79,10 @@ void b3(int a, int b, int c, int d) {
// CHECK-NEXT: %13 = cir.load %2
// CHECK-NEXT: %14 = cir.load %3
// CHECK-NEXT: %15 = cir.cmp(eq, %13, %14)
// CHECK-NEXT: cir.ternary(%15, true
// CHECK: %9 = cir.load %0
// CHECK-NEXT: %10 = cir.load %1
// CHECK-NEXT: %11 = cir.cmp(eq, %9, %10)
// CHECK-NEXT: %12 = cir.ternary(%11, true {
// CHECK: }, false {
// CHECK-NEXT: %13 = cir.load %2
// CHECK-NEXT: %14 = cir.load %3
// CHECK-NEXT: %15 = cir.cmp(eq, %13, %14)
// CHECK-NEXT: %16 = cir.ternary(%15, true
// CHECK-NEXT: cir.yield %15
// CHECK-NEXT: }, false {
// CHECK-NEXT: %13 = cir.const #false
// CHECK-NEXT: cir.yield %13

void testFloatingPointBinOps(float a, float b) {
a * b;
Expand Down
16 changes: 6 additions & 10 deletions clang/test/CIR/CodeGen/ternary.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,12 @@ int x(int y) {
// CHECK: %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
// CHECK: %3 = cir.const #cir.int<0> : !s32i
// CHECK: %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool
// CHECK: %5 = cir.ternary(%4, true {
// CHECK: %7 = cir.const #cir.int<3> : !s32i
// CHECK: cir.yield %7 : !s32i
// CHECK: }, false {
// CHECK: %7 = cir.const #cir.int<5> : !s32i
// CHECK: cir.yield %7 : !s32i
// CHECK: }) : (!cir.bool) -> !s32i
// CHECK: cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
// CHECK: %6 = cir.load %1 : !cir.ptr<!s32i>, !s32i
// CHECK: cir.return %6 : !s32i
// CHECK: %5 = cir.const #cir.int<3> : !s32i
// CHECK: %6 = cir.const #cir.int<5> : !s32i
// CHECK: %7 = cir.select if %4 then %5 else %6 : (!cir.bool, !s32i, !s32i) -> !s32i
// CHECK: cir.store %7, %1 : !s32i, !cir.ptr<!s32i>
// CHECK: %8 = cir.load %1 : !cir.ptr<!s32i>, !s32i
// CHECK: cir.return %8 : !s32i
// CHECK: }

typedef enum {
Expand Down
60 changes: 60 additions & 0 deletions clang/test/CIR/Transforms/ternary-fold.cir
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// RUN: cir-opt -cir-simplify -o %t.cir %s
// RUN: FileCheck --input-file=%t.cir %s

!s32i = !cir.int<s, 32>

module {
cir.func @fold_ternary(%arg0: !s32i, %arg1: !s32i) -> !s32i {
%0 = cir.const #cir.bool<false> : !cir.bool
%1 = cir.ternary (%0, true {
cir.yield %arg0 : !s32i
}, false {
cir.yield %arg1 : !s32i
}) : (!cir.bool) -> !s32i
cir.return %1 : !s32i
}

// CHECK: cir.func @fold_ternary(%{{.+}}: !s32i, %[[ARG:.+]]: !s32i) -> !s32i {
// CHECK-NEXT: cir.return %[[ARG]] : !s32i
// CHECK-NEXT: }

cir.func @simplify_ternary(%arg0 : !cir.bool, %arg1 : !s32i) -> !s32i {
%0 = cir.ternary (%arg0, true {
%1 = cir.const #cir.int<42> : !s32i
cir.yield %1 : !s32i
}, false {
cir.yield %arg1 : !s32i
}) : (!cir.bool) -> !s32i
cir.return %0 : !s32i
}

// CHECK: cir.func @simplify_ternary(%[[ARG0:.+]]: !cir.bool, %[[ARG1:.+]]: !s32i) -> !s32i {
// CHECK-NEXT: %[[#A:]] = cir.const #cir.int<42> : !s32i
// CHECK-NEXT: %[[#B:]] = cir.select if %[[ARG0]] then %[[#A]] else %[[ARG1]] : (!cir.bool, !s32i, !s32i) -> !s32i
// CHECK-NEXT: cir.return %[[#B]] : !s32i
// CHECK-NEXT: }

cir.func @non_simplifiable_ternary(%arg0 : !cir.bool) -> !s32i {
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init]
%1 = cir.ternary (%arg0, true {
%2 = cir.const #cir.int<42> : !s32i
cir.yield %2 : !s32i
}, false {
%3 = cir.load %0 : !cir.ptr<!s32i>, !s32i
cir.yield %3 : !s32i
}) : (!cir.bool) -> !s32i
cir.return %1 : !s32i
}

// CHECK: cir.func @non_simplifiable_ternary(%[[ARG0:.+]]: !cir.bool) -> !s32i {
// CHECK-NEXT: %[[#A:]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init]
// CHECK-NEXT: %[[#B:]] = cir.ternary(%[[ARG0]], true {
// CHECK-NEXT: %[[#C:]] = cir.const #cir.int<42> : !s32i
// CHECK-NEXT: cir.yield %[[#C]] : !s32i
// CHECK-NEXT: }, false {
// CHECK-NEXT: %[[#D:]] = cir.load %[[#A]] : !cir.ptr<!s32i>, !s32i
// CHECK-NEXT: cir.yield %[[#D]] : !s32i
// CHECK-NEXT: }) : (!cir.bool) -> !s32i
// CHECK-NEXT: cir.return %[[#B]] : !s32i
// CHECK-NEXT: }
}
56 changes: 56 additions & 0 deletions clang/test/CIR/Transforms/ternary-fold.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir -mmlir --mlir-print-ir-before=cir-simplify %s -o %t1.cir 2>&1 | FileCheck -check-prefix=CIR-BEFORE %s
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir -mmlir --mlir-print-ir-after=cir-simplify %s -o %t2.cir 2>&1 | FileCheck -check-prefix=CIR-AFTER %s
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t.ll
// RUN: FileCheck --input-file=%t.ll --check-prefix=LLVM %s

int test(bool x) {
return x ? 1 : 2;
}

// CIR-BEFORE: cir.func @_Z4testb
// CIR-BEFORE: %{{.+}} = cir.ternary(%{{.+}}, true {
// CIR-BEFORE-NEXT: %[[#A:]] = cir.const #cir.int<1> : !s32i
// CIR-BEFORE-NEXT: cir.yield %[[#A]] : !s32i
// CIR-BEFORE-NEXT: }, false {
// CIR-BEFORE-NEXT: %[[#B:]] = cir.const #cir.int<2> : !s32i
// CIR-BEFORE-NEXT: cir.yield %[[#B]] : !s32i
// CIR-BEFORE-NEXT: }) : (!cir.bool) -> !s32i
// CIR-BEFORE: }

// CIR-AFTER: cir.func @_Z4testb
// CIR-AFTER: %[[#A:]] = cir.const #cir.int<1> : !s32i
// CIR-AFTER-NEXT: %[[#B:]] = cir.const #cir.int<2> : !s32i
// CIR-AFTER-NEXT: %{{.+}} = cir.select if %{{.+}} then %[[#A]] else %[[#B]] : (!cir.bool, !s32i, !s32i) -> !s32i
// CIR-AFTER: }

// LLVM: define dso_local i32 @_Z4testb
// LLVM: %{{.+}} = select i1 %{{.+}}, i32 1, i32 2
// LLVM: }

int test2(bool cond) {
constexpr int x = 1;
constexpr int y = 2;
return cond ? x : y;
}

// CIR-BEFORE: cir.func @_Z5test2b
// CIR-BEFORE: %[[#COND:]] = cir.load %{{.+}} : !cir.ptr<!cir.bool>, !cir.bool
// CIR-BEFORE-NEXT: %{{.+}} = cir.ternary(%[[#COND]], true {
// CIR-BEFORE-NEXT: %[[#A:]] = cir.const #cir.int<1> : !s32i
// CIR-BEFORE-NEXT: cir.yield %[[#A]] : !s32i
// CIR-BEFORE-NEXT: }, false {
// CIR-BEFORE-NEXT: %[[#B:]] = cir.const #cir.int<2> : !s32i
// CIR-BEFORE-NEXT: cir.yield %[[#B]] : !s32i
// CIR-BEFORE-NEXT: }) : (!cir.bool) -> !s32i
// CIR-BEFORE: }

// CIR-AFTER: cir.func @_Z5test2b
// CIR-AFTER: %[[#COND:]] = cir.load %{{.+}} : !cir.ptr<!cir.bool>, !cir.bool
// CIR-AFTER-NEXT: %[[#A:]] = cir.const #cir.int<1> : !s32i
// CIR-AFTER-NEXT: %[[#B:]] = cir.const #cir.int<2> : !s32i
// CIR-AFTER-NEXT: %{{.+}} = cir.select if %[[#COND]] then %[[#A]] else %[[#B]] : (!cir.bool, !s32i, !s32i) -> !s32i
// CIR-AFTER: }

// LLVM: define dso_local i32 @_Z5test2b
// LLVM: %{{.+}} = select i1 %{{.+}}, i32 1, i32 2
// LLVM: }