Skip to content

Commit d9c5e68

Browse files
committed
[CIR][Transform] Add ternary simplification
This patch adds a new transformation that transform suitable ternary operations into select operations.
1 parent aa282a9 commit d9c5e68

File tree

5 files changed

+224
-28
lines changed

5 files changed

+224
-28
lines changed

clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp

Lines changed: 88 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,15 @@
88

99
#include "PassDetail.h"
1010
#include "mlir/Dialect/Func/IR/FuncOps.h"
11+
#include "mlir/IR/Block.h"
12+
#include "mlir/IR/Operation.h"
1113
#include "mlir/IR/PatternMatch.h"
14+
#include "mlir/IR/Region.h"
1215
#include "mlir/Support/LogicalResult.h"
1316
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1417
#include "clang/CIR/Dialect/IR/CIRDialect.h"
1518
#include "clang/CIR/Dialect/Passes.h"
19+
#include "llvm/ADT/SmallVector.h"
1620

1721
using namespace mlir;
1822
using namespace cir;
@@ -107,6 +111,85 @@ struct RemoveTrivialTry : public OpRewritePattern<TryOp> {
107111
}
108112
};
109113

114+
/// Simplify suitable ternary operations into select operations.
115+
///
116+
/// For now we only simplify those ternary operations whose true and false
117+
/// branches directly yield a value or a constant. That is, both of the true and
118+
/// the false branch must either contain a cir.yield operation as the only
119+
/// operation in the branch, or contain a cir.const operation followed by a
120+
/// cir.yield operation that yields the constant value.
121+
///
122+
/// For example, we will simplify the following ternary operation:
123+
///
124+
/// %0 = cir.ternary (%condition, true {
125+
/// %1 = cir.const ...
126+
/// cir.yield %1
127+
/// } false {
128+
/// cir.yield %2
129+
/// })
130+
///
131+
/// into the following sequence of operations:
132+
///
133+
/// %1 = cir.const ...
134+
/// %0 = cir.select if %condition then %1 else %2
135+
struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
136+
using OpRewritePattern<TernaryOp>::OpRewritePattern;
137+
138+
LogicalResult matchAndRewrite(TernaryOp op,
139+
PatternRewriter &rewriter) const override {
140+
if (op->getNumResults() != 1)
141+
return mlir::failure();
142+
143+
if (!isSimpleTernaryBranch(op.getTrueRegion()) ||
144+
!isSimpleTernaryBranch(op.getFalseRegion()))
145+
return mlir::failure();
146+
147+
mlir::cir::YieldOp trueBranchYieldOp = mlir::cast<mlir::cir::YieldOp>(
148+
op.getTrueRegion().front().getTerminator());
149+
mlir::cir::YieldOp falseBranchYieldOp = mlir::cast<mlir::cir::YieldOp>(
150+
op.getFalseRegion().front().getTerminator());
151+
auto trueValue = trueBranchYieldOp.getArgs()[0];
152+
auto falseValue = falseBranchYieldOp.getArgs()[0];
153+
154+
rewriter.inlineBlockBefore(&op.getTrueRegion().front(), op);
155+
rewriter.inlineBlockBefore(&op.getFalseRegion().front(), op);
156+
trueBranchYieldOp.erase();
157+
falseBranchYieldOp.erase();
158+
rewriter.replaceOpWithNewOp<mlir::cir::SelectOp>(op, op.getCond(),
159+
trueValue, falseValue);
160+
161+
return mlir::success();
162+
}
163+
164+
private:
165+
bool isSimpleTernaryBranch(mlir::Region &region) const {
166+
if (!region.hasOneBlock())
167+
return false;
168+
169+
mlir::Block &onlyBlock = region.front();
170+
auto &ops = onlyBlock.getOperations();
171+
172+
// The region/block could only contain at most 2 operations.
173+
if (ops.size() > 2)
174+
return false;
175+
176+
if (ops.size() == 1) {
177+
// The region/block only contain a cir.yield operation.
178+
return true;
179+
}
180+
181+
// Check whether the region/block contains a cir.const followed by a
182+
// cir.yield that yields the value.
183+
auto yieldOp = mlir::cast<mlir::cir::YieldOp>(onlyBlock.getTerminator());
184+
auto yieldValueDefOp = mlir::dyn_cast_if_present<mlir::cir::ConstantOp>(
185+
yieldOp.getArgs()[0].getDefiningOp());
186+
if (!yieldValueDefOp || yieldValueDefOp->getBlock() != &onlyBlock)
187+
return false;
188+
189+
return true;
190+
}
191+
};
192+
110193
//===----------------------------------------------------------------------===//
111194
// CIRSimplifyPass
112195
//===----------------------------------------------------------------------===//
@@ -131,7 +214,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
131214
RemoveRedundantBranches,
132215
RemoveEmptyScope,
133216
RemoveEmptySwitch,
134-
RemoveTrivialTry
217+
RemoveTrivialTry,
218+
SimplifyTernary
135219
>(patterns.getContext());
136220
// clang-format on
137221
}
@@ -146,8 +230,9 @@ void CIRSimplifyPass::runOnOperation() {
146230
getOperation()->walk([&](Operation *op) {
147231
// CastOp here is to perform a manual `fold` in
148232
// applyOpPatternsAndFold
149-
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp, SelectOp,
150-
ComplexCreateOp, ComplexRealOp, ComplexImagOp>(op))
233+
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp,
234+
TernaryOp, SelectOp, ComplexCreateOp, ComplexRealOp, ComplexImagOp>(
235+
op))
151236
ops.push_back(op);
152237
});
153238

clang/test/CIR/CodeGen/binop.cpp

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,10 @@ void b1(bool a, bool b) {
3232

3333
// CHECK: cir.ternary(%3, true
3434
// CHECK-NEXT: %7 = cir.load %1
35-
// CHECK-NEXT: cir.ternary(%7, true
36-
// CHECK-NEXT: cir.const #true
37-
// CHECK-NEXT: cir.yield
38-
// CHECK-NEXT: false {
39-
// CHECK-NEXT: cir.const #false
40-
// CHECK-NEXT: cir.yield
41-
// CHECK: cir.yield
35+
// CHECK-NEXT: %8 = cir.const #true
36+
// CHECK-NEXT: %9 = cir.const #false
37+
// CHECK-NEXT: %10 = cir.select if %7 then %8 else %9
38+
// CHECK-NEXT: cir.yield %10
4239
// CHECK-NEXT: false {
4340
// CHECK-NEXT: cir.const #false
4441
// CHECK-NEXT: cir.yield
@@ -48,12 +45,10 @@ void b1(bool a, bool b) {
4845
// CHECK-NEXT: cir.yield
4946
// CHECK-NEXT: false {
5047
// CHECK-NEXT: %7 = cir.load %1
51-
// CHECK-NEXT: cir.ternary(%7, true
52-
// CHECK-NEXT: cir.const #true
53-
// CHECK-NEXT: cir.yield
54-
// CHECK-NEXT: false {
55-
// CHECK-NEXT: cir.const #false
56-
// CHECK-NEXT: cir.yield
48+
// CHECK-NEXT: %8 = cir.const #true
49+
// CHECK-NEXT: %9 = cir.const #false
50+
// CHECK-NEXT: %10 = cir.select if %7 then %8 else %9
51+
// CHECK-NEXT: cir.yield %10
5752

5853
void b2(bool a) {
5954
bool x = 0 && a;
@@ -90,7 +85,9 @@ void b3(int a, int b, int c, int d) {
9085
// CHECK-NEXT: %13 = cir.load %2
9186
// CHECK-NEXT: %14 = cir.load %3
9287
// CHECK-NEXT: %15 = cir.cmp(eq, %13, %14)
93-
// CHECK-NEXT: cir.ternary(%15, true
88+
// CHECK-NEXT: %16 = cir.const #true
89+
// CHECK-NEXT: %17 = cir.const #false
90+
// CHECK-NEXT: %18 = cir.select if %15 then %16 else %17
9491
// CHECK: %9 = cir.load %0
9592
// CHECK-NEXT: %10 = cir.load %1
9693
// CHECK-NEXT: %11 = cir.cmp(eq, %9, %10)
@@ -99,7 +96,9 @@ void b3(int a, int b, int c, int d) {
9996
// CHECK-NEXT: %13 = cir.load %2
10097
// CHECK-NEXT: %14 = cir.load %3
10198
// CHECK-NEXT: %15 = cir.cmp(eq, %13, %14)
102-
// CHECK-NEXT: %16 = cir.ternary(%15, true
99+
// CHECK-NEXT: %16 = cir.const #true
100+
// CHECK-NEXT: %17 = cir.const #false
101+
// CHECK-NEXT: %18 = cir.select if %15 then %16 else %17
103102

104103
void testFloatingPointBinOps(float a, float b) {
105104
a * b;

clang/test/CIR/CodeGen/ternary.cpp

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,12 @@ int x(int y) {
1212
// CHECK: %2 = cir.load %0 : !cir.ptr<!s32i>, !s32i
1313
// CHECK: %3 = cir.const #cir.int<0> : !s32i
1414
// CHECK: %4 = cir.cmp(gt, %2, %3) : !s32i, !cir.bool
15-
// CHECK: %5 = cir.ternary(%4, true {
16-
// CHECK: %7 = cir.const #cir.int<3> : !s32i
17-
// CHECK: cir.yield %7 : !s32i
18-
// CHECK: }, false {
19-
// CHECK: %7 = cir.const #cir.int<5> : !s32i
20-
// CHECK: cir.yield %7 : !s32i
21-
// CHECK: }) : (!cir.bool) -> !s32i
22-
// CHECK: cir.store %5, %1 : !s32i, !cir.ptr<!s32i>
23-
// CHECK: %6 = cir.load %1 : !cir.ptr<!s32i>, !s32i
24-
// CHECK: cir.return %6 : !s32i
15+
// CHECK: %5 = cir.const #cir.int<3> : !s32i
16+
// CHECK: %6 = cir.const #cir.int<5> : !s32i
17+
// CHECK: %7 = cir.select if %4 then %5 else %6 : (!cir.bool, !s32i, !s32i) -> !s32i
18+
// CHECK: cir.store %7, %1 : !s32i, !cir.ptr<!s32i>
19+
// CHECK: %8 = cir.load %1 : !cir.ptr<!s32i>, !s32i
20+
// CHECK: cir.return %8 : !s32i
2521
// CHECK: }
2622

2723
typedef enum {
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
// RUN: cir-opt -cir-simplify -o %t.cir %s
2+
// RUN: FileCheck --input-file=%t.cir %s
3+
4+
!s32i = !cir.int<s, 32>
5+
6+
module {
7+
cir.func @fold_ternary(%arg0: !s32i, %arg1: !s32i) -> !s32i {
8+
%0 = cir.const #cir.bool<false> : !cir.bool
9+
%1 = cir.ternary (%0, true {
10+
cir.yield %arg0 : !s32i
11+
}, false {
12+
cir.yield %arg1 : !s32i
13+
}) : (!cir.bool) -> !s32i
14+
cir.return %1 : !s32i
15+
}
16+
17+
// CHECK: cir.func @fold_ternary(%{{.+}}: !s32i, %[[ARG:.+]]: !s32i) -> !s32i {
18+
// CHECK-NEXT: cir.return %[[ARG]] : !s32i
19+
// CHECK-NEXT: }
20+
21+
cir.func @simplify_ternary(%arg0 : !cir.bool, %arg1 : !s32i) -> !s32i {
22+
%0 = cir.ternary (%arg0, true {
23+
%1 = cir.const #cir.int<42> : !s32i
24+
cir.yield %1 : !s32i
25+
}, false {
26+
cir.yield %arg1 : !s32i
27+
}) : (!cir.bool) -> !s32i
28+
cir.return %0 : !s32i
29+
}
30+
31+
// CHECK: cir.func @simplify_ternary(%[[ARG0:.+]]: !cir.bool, %[[ARG1:.+]]: !s32i) -> !s32i {
32+
// CHECK-NEXT: %[[#A:]] = cir.const #cir.int<42> : !s32i
33+
// CHECK-NEXT: %[[#B:]] = cir.select if %[[ARG0]] then %[[#A]] else %[[ARG1]] : (!cir.bool, !s32i, !s32i) -> !s32i
34+
// CHECK-NEXT: cir.return %[[#B]] : !s32i
35+
// CHECK-NEXT: }
36+
37+
cir.func @non_simplifiable_ternary(%arg0 : !cir.bool) -> !s32i {
38+
%0 = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init]
39+
%1 = cir.ternary (%arg0, true {
40+
%2 = cir.const #cir.int<42> : !s32i
41+
cir.yield %2 : !s32i
42+
}, false {
43+
%3 = cir.load %0 : !cir.ptr<!s32i>, !s32i
44+
cir.yield %3 : !s32i
45+
}) : (!cir.bool) -> !s32i
46+
cir.return %1 : !s32i
47+
}
48+
49+
// CHECK: cir.func @non_simplifiable_ternary(%[[ARG0:.+]]: !cir.bool) -> !s32i {
50+
// CHECK-NEXT: %[[#A:]] = cir.alloca !s32i, !cir.ptr<!s32i>, ["a", init]
51+
// CHECK-NEXT: %[[#B:]] = cir.ternary(%[[ARG0]], true {
52+
// CHECK-NEXT: %[[#C:]] = cir.const #cir.int<42> : !s32i
53+
// CHECK-NEXT: cir.yield %[[#C]] : !s32i
54+
// CHECK-NEXT: }, false {
55+
// CHECK-NEXT: %[[#D:]] = cir.load %[[#A]] : !cir.ptr<!s32i>, !s32i
56+
// CHECK-NEXT: cir.yield %[[#D]] : !s32i
57+
// CHECK-NEXT: }) : (!cir.bool) -> !s32i
58+
// CHECK-NEXT: cir.return %[[#B]] : !s32i
59+
// CHECK-NEXT: }
60+
}
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir -mmlir --mlir-print-ir-before=cir-simplify %s -o %t1.cir 2>&1 | FileCheck -check-prefix=CIR-BEFORE %s
2+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir -mmlir --mlir-print-ir-after=cir-simplify %s -o %t2.cir 2>&1 | FileCheck -check-prefix=CIR-AFTER %s
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t.ll
4+
// RUN: FileCheck --input-file=%t.ll --check-prefix=LLVM %s
5+
6+
int test(bool x) {
7+
return x ? 1 : 2;
8+
}
9+
10+
// CIR-BEFORE: cir.func @_Z4testb
11+
// CIR-BEFORE: %{{.+}} = cir.ternary(%{{.+}}, true {
12+
// CIR-BEFORE-NEXT: %[[#A:]] = cir.const #cir.int<1> : !s32i
13+
// CIR-BEFORE-NEXT: cir.yield %[[#A]] : !s32i
14+
// CIR-BEFORE-NEXT: }, false {
15+
// CIR-BEFORE-NEXT: %[[#B:]] = cir.const #cir.int<2> : !s32i
16+
// CIR-BEFORE-NEXT: cir.yield %[[#B]] : !s32i
17+
// CIR-BEFORE-NEXT: }) : (!cir.bool) -> !s32i
18+
// CIR-BEFORE: }
19+
20+
// CIR-AFTER: cir.func @_Z4testb
21+
// CIR-AFTER: %[[#A:]] = cir.const #cir.int<1> : !s32i
22+
// CIR-AFTER-NEXT: %[[#B:]] = cir.const #cir.int<2> : !s32i
23+
// CIR-AFTER-NEXT: %{{.+}} = cir.select if %{{.+}} then %[[#A]] else %[[#B]] : (!cir.bool, !s32i, !s32i) -> !s32i
24+
// CIR-AFTER: }
25+
26+
// LLVM: define dso_local i32 @_Z4testb
27+
// LLVM: %{{.+}} = select i1 %{{.+}}, i32 1, i32 2
28+
// LLVM: }
29+
30+
int test2(bool cond) {
31+
constexpr int x = 1;
32+
constexpr int y = 2;
33+
return cond ? x : y;
34+
}
35+
36+
// CIR-BEFORE: cir.func @_Z5test2b
37+
// CIR-BEFORE: %[[#COND:]] = cir.load %{{.+}} : !cir.ptr<!cir.bool>, !cir.bool
38+
// CIR-BEFORE-NEXT: %{{.+}} = cir.ternary(%[[#COND]], true {
39+
// CIR-BEFORE-NEXT: %[[#A:]] = cir.const #cir.int<1> : !s32i
40+
// CIR-BEFORE-NEXT: cir.yield %[[#A]] : !s32i
41+
// CIR-BEFORE-NEXT: }, false {
42+
// CIR-BEFORE-NEXT: %[[#B:]] = cir.const #cir.int<2> : !s32i
43+
// CIR-BEFORE-NEXT: cir.yield %[[#B]] : !s32i
44+
// CIR-BEFORE-NEXT: }) : (!cir.bool) -> !s32i
45+
// CIR-BEFORE: }
46+
47+
// CIR-AFTER: cir.func @_Z5test2b
48+
// CIR-AFTER: %[[#COND:]] = cir.load %{{.+}} : !cir.ptr<!cir.bool>, !cir.bool
49+
// CIR-AFTER-NEXT: %[[#A:]] = cir.const #cir.int<1> : !s32i
50+
// CIR-AFTER-NEXT: %[[#B:]] = cir.const #cir.int<2> : !s32i
51+
// CIR-AFTER-NEXT: %{{.+}} = cir.select if %[[#COND]] then %[[#A]] else %[[#B]] : (!cir.bool, !s32i, !s32i) -> !s32i
52+
// CIR-AFTER: }
53+
54+
// LLVM: define dso_local i32 @_Z5test2b
55+
// LLVM: %{{.+}} = select i1 %{{.+}}, i32 1, i32 2
56+
// LLVM: }

0 commit comments

Comments
 (0)