Skip to content

Commit 92339bf

Browse files
committed
[CIR][Transform] Add ternary simplification
This patch adds a new transformation that transform suitable ternary operations into select operations.
1 parent d3af20c commit 92339bf

File tree

3 files changed

+189
-3
lines changed

3 files changed

+189
-3
lines changed

clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp

Lines changed: 95 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,15 @@
88

99
#include "PassDetail.h"
1010
#include "mlir/Dialect/Func/IR/FuncOps.h"
11+
#include "mlir/IR/Block.h"
12+
#include "mlir/IR/Operation.h"
1113
#include "mlir/IR/PatternMatch.h"
14+
#include "mlir/IR/Region.h"
1215
#include "mlir/Support/LogicalResult.h"
1316
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1417
#include "clang/CIR/Dialect/IR/CIRDialect.h"
1518
#include "clang/CIR/Dialect/Passes.h"
19+
#include "llvm/ADT/SmallVector.h"
1620

1721
using namespace mlir;
1822
using namespace cir;
@@ -107,6 +111,92 @@ struct RemoveTrivialTry : public OpRewritePattern<TryOp> {
107111
}
108112
};
109113

114+
/// Simplify suitable ternary operations into select operations.
115+
///
116+
/// Only those ternary operations that meet the following criteria can be
117+
/// simplified:
118+
/// - The true branch and the false branch cannot have any side effects;
119+
/// - The true branch and the false branch cannot be "too costly" since both of
120+
/// them will be executed after the folding happens.
121+
///
122+
/// For now we only simplify those ternary operations whose true and false
123+
/// branches either directly yield a value or directly yield a constant. That
124+
/// is, both of the two branches of these ternary operation must either:
125+
/// - Only contain a single cir.yield operation, or
126+
/// - Contain a cir.const operation followed by a cir.yield operation that
127+
/// yields the constant value produced by the cir.const operation.
128+
///
129+
/// For example, we will simplify the following ternary operation:
130+
///
131+
/// %0 = cir.ternary (%condition, true {
132+
/// %1 = cir.const ...
133+
/// cir.yield %1
134+
/// } false {
135+
/// cir.yield %2
136+
/// })
137+
///
138+
/// into the following sequence of operations:
139+
///
140+
/// %1 = cir.const ...
141+
/// %0 = cir.select if %condition then %1 else %2
142+
struct SimplifyTernary final : public OpRewritePattern<TernaryOp> {
143+
using OpRewritePattern<TernaryOp>::OpRewritePattern;
144+
145+
LogicalResult matchAndRewrite(TernaryOp op,
146+
PatternRewriter &rewriter) const override {
147+
llvm::SmallVector<mlir::Operation *> opsToHoist;
148+
149+
mlir::Value trueValue =
150+
simplifyTernaryBranch(op.getTrueRegion(), opsToHoist);
151+
if (!trueValue)
152+
return mlir::failure();
153+
154+
mlir::Value falseValue =
155+
simplifyTernaryBranch(op.getFalseRegion(), opsToHoist);
156+
if (!falseValue)
157+
return mlir::failure();
158+
159+
for (auto *hoistOp : opsToHoist)
160+
rewriter.moveOpBefore(hoistOp, op);
161+
rewriter.replaceOpWithNewOp<mlir::cir::SelectOp>(op, op.getCond(),
162+
trueValue, falseValue);
163+
164+
return mlir::success();
165+
}
166+
167+
private:
168+
mlir::Value simplifyTernaryBranch(
169+
mlir::Region &region,
170+
llvm::SmallVector<mlir::Operation *> &opsToHoist) const {
171+
if (!region.hasOneBlock())
172+
return nullptr;
173+
174+
mlir::Block &block = region.front();
175+
176+
// The block can contain at most 2 operations: one cir.const operation
177+
// followed by one cir.yield operation
178+
if (block.getOperations().size() > 2)
179+
return nullptr;
180+
181+
auto yieldOp = mlir::cast<mlir::cir::YieldOp>(block.getTerminator());
182+
auto yieldValue = yieldOp.getArgs()[0];
183+
if (block.getOperations().size() == 1)
184+
return yieldValue;
185+
186+
// The yielded value must be produced by a cir.const operation in the same
187+
// block to make the branch simplifiable.
188+
auto yieldValueDef = mlir::dyn_cast_if_present<mlir::cir::ConstantOp>(
189+
yieldValue.getDefiningOp());
190+
if (!yieldValueDef)
191+
return nullptr;
192+
if (yieldValueDef->getBlock() != &block)
193+
return nullptr;
194+
195+
opsToHoist.push_back(yieldValueDef);
196+
return yieldValue;
197+
}
198+
};
199+
110200
//===----------------------------------------------------------------------===//
111201
// CIRSimplifyPass
112202
//===----------------------------------------------------------------------===//
@@ -131,7 +221,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
131221
RemoveRedundantBranches,
132222
RemoveEmptyScope,
133223
RemoveEmptySwitch,
134-
RemoveTrivialTry
224+
RemoveTrivialTry,
225+
SimplifyTernary
135226
>(patterns.getContext());
136227
// clang-format on
137228
}
@@ -146,8 +237,9 @@ void CIRSimplifyPass::runOnOperation() {
146237
getOperation()->walk([&](Operation *op) {
147238
// CastOp here is to perform a manual `fold` in
148239
// applyOpPatternsAndFold
149-
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp, SelectOp,
150-
ComplexCreateOp, ComplexRealOp, ComplexImagOp>(op))
240+
if (isa<BrOp, BrCondOp, ScopeOp, SwitchOp, CastOp, TryOp, UnaryOp,
241+
TernaryOp, SelectOp, ComplexCreateOp, ComplexRealOp, ComplexImagOp>(
242+
op))
151243
ops.push_back(op);
152244
});
153245

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
// RUN: cir-opt -cir-simplify -o %t.cir %s
2+
// RUN: FileCheck --input-file=%t.cir %s
3+
4+
!s32i = !cir.int<s, 32>
5+
6+
module {
7+
cir.func @fold_ternary(%arg0: !s32i, %arg1: !s32i) -> !s32i {
8+
%0 = cir.const #cir.bool<false> : !cir.bool
9+
%1 = cir.ternary (%0, true {
10+
cir.yield %arg0 : !s32i
11+
}, false {
12+
cir.yield %arg1 : !s32i
13+
}) : (!cir.bool) -> !s32i
14+
cir.return %1 : !s32i
15+
}
16+
17+
// CHECK: cir.func @fold_ternary(%{{.+}}: !s32i, %[[ARG:.+]]: !s32i) -> !s32i {
18+
// CHECK-NEXT: cir.return %[[ARG]] : !s32i
19+
// CHECK-NEXT: }
20+
21+
cir.func @simplify_ternary(%arg0 : !cir.bool, %arg1 : !s32i) -> !s32i {
22+
%0 = cir.ternary (%arg0, true {
23+
%1 = cir.const #cir.int<42> : !s32i
24+
cir.yield %1 : !s32i
25+
}, false {
26+
cir.yield %arg1 : !s32i
27+
}) : (!cir.bool) -> !s32i
28+
cir.return %0 : !s32i
29+
}
30+
31+
// CHECK: cir.func @simplify_ternary(%[[ARG0:.+]]: !cir.bool, %[[ARG1:.+]]: !s32i) -> !s32i {
32+
// CHECK-NEXT: %[[#A:]] = cir.const #cir.int<42> : !s32i
33+
// CHECK-NEXT: %[[#B:]] = cir.select if %[[ARG0]] then %[[#A]] else %[[ARG1]] : (!cir.bool, !s32i, !s32i) -> !s32i
34+
// CHECK-NEXT: cir.return %[[#B]] : !s32i
35+
// CHECK-NEXT: }
36+
37+
cir.func @non_simplifiable_ternary(%arg0 : !cir.bool, %arg1 : !cir.ptr<!s32i>) -> !s32i {
38+
// Not simplifiable, should keep as-is.
39+
%0 = cir.ternary (%arg0, true {
40+
%1 = cir.load %arg1 : !cir.ptr<!s32i>, !s32i
41+
cir.yield %1 : !s32i
42+
}, false {
43+
%2 = cir.const #cir.int<42> : !s32i
44+
cir.yield %2 : !s32i
45+
}) : (!cir.bool) -> !s32i
46+
cir.return %0 : !s32i
47+
}
48+
49+
// CHECK: cir.func @non_simplifiable_ternary(%[[ARG0:.+]]: !cir.bool, %[[ARG1:.+]]: !cir.ptr<!s32i>) -> !s32i {
50+
// CHECK-NEXT: %[[#A:]] = cir.ternary(%[[ARG0]], true {
51+
// CHECK-NEXT: %[[#B:]] = cir.load %[[ARG1]] : !cir.ptr<!s32i>, !s32i
52+
// CHECK-NEXT: cir.yield %[[#B]] : !s32i
53+
// CHECK-NEXT: }, false {
54+
// CHECK-NEXT: %[[#C:]] = cir.const #cir.int<42> : !s32i
55+
// CHECK-NEXT: cir.yield %[[#C]] : !s32i
56+
// CHECK-NEXT: }) : (!cir.bool) -> !s32i
57+
// CHECK-NEXT: cir.return %[[#A]] : !s32i
58+
// CHECK-NEXT: }
59+
}
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir -fclangir-mem2reg -mmlir --mlir-print-ir-before=cir-simplify %s -o %t1.cir 2>&1 | FileCheck -check-prefix=CIR-BEFORE %s
2+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-cir -fclangir-mem2reg -mmlir --mlir-print-ir-after=cir-simplify %s -o %t2.cir 2>&1 | FileCheck -check-prefix=CIR-AFTER %s
3+
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fclangir -emit-llvm %s -o %t.ll
4+
// RUN: FileCheck --input-file=%t.ll --check-prefix=LLVM %s
5+
6+
int test1(bool x) {
7+
return x ? 1 : 2;
8+
}
9+
10+
// CIR-BEFORE: cir.func @_Z5test1b
11+
// CIR-BEFORE: %{{.+}} = cir.ternary(%{{.+}}, true {
12+
// CIR-BEFORE-NEXT: %[[#A:]] = cir.const #cir.int<1> : !s32i
13+
// CIR-BEFORE-NEXT: cir.yield %[[#A]] : !s32i
14+
// CIR-BEFORE-NEXT: }, false {
15+
// CIR-BEFORE-NEXT: %[[#B:]] = cir.const #cir.int<2> : !s32i
16+
// CIR-BEFORE-NEXT: cir.yield %[[#B]] : !s32i
17+
// CIR-BEFORE-NEXT: }) : (!cir.bool) -> !s32i
18+
// CIR-BEFORE: }
19+
20+
// CIR-AFTER: cir.func @_Z5test1b
21+
// CIR-AFTER: %[[#A:]] = cir.const #cir.int<1> : !s32i
22+
// CIR-AFTER-NEXT: %[[#B:]] = cir.const #cir.int<2> : !s32i
23+
// CIR-AFTER-NEXT: %{{.+}} = cir.select if %{{.+}} then %[[#A]] else %[[#B]] : (!cir.bool, !s32i, !s32i) -> !s32i
24+
// CIR-AFTER: }
25+
26+
// LLVM: define dso_local i32 @_Z5test1b
27+
// LLVM: %{{.+}} = select i1 %{{.+}}, i32 1, i32 2
28+
// LLVM: }
29+
30+
// The following test does not work yet because mem2reg does not happen before
31+
// ternary simplify.
32+
33+
// int test2(bool x, int a, int b) {
34+
// return x ? a : b;
35+
// }

0 commit comments

Comments
 (0)