Skip to content

Commit 8b2274d

Browse files
authored
[CIR][Transform] Add simplify transformation for select op (#816)
As mentioned at #809 (comment) , this PR adds more simplify transformations for select op: - `cir.select if %0 then x else x` -> `x` - `cir.select if %0 then #true else #false` -> `%0` - `cir.select if %0 then #false else #true` -> `cir.unary not %0`
1 parent 016405c commit 8b2274d

File tree

3 files changed

+89
-6
lines changed

3 files changed

+89
-6
lines changed

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1400,11 +1400,20 @@ void TernaryOp::build(OpBuilder &builder, OperationState &result, Value cond,
14001400

14011401
OpFoldResult SelectOp::fold(FoldAdaptor adaptor) {
14021402
auto condition = adaptor.getCondition();
1403-
if (!condition)
1404-
return nullptr;
1403+
if (condition) {
1404+
auto conditionValue = mlir::cast<mlir::cir::BoolAttr>(condition).getValue();
1405+
return conditionValue ? getTrueValue() : getFalseValue();
1406+
}
14051407

1406-
auto conditionValue = mlir::cast<mlir::cir::BoolAttr>(condition).getValue();
1407-
return conditionValue ? getTrueValue() : getFalseValue();
1408+
// cir.select if %0 then x else x -> x
1409+
auto trueValue = adaptor.getTrueValue();
1410+
auto falseValue = adaptor.getFalseValue();
1411+
if (trueValue && trueValue == falseValue)
1412+
return trueValue;
1413+
if (getTrueValue() == getFalseValue())
1414+
return getTrueValue();
1415+
1416+
return nullptr;
14081417
}
14091418

14101419
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/CIRSimplify.cpp

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,45 @@ struct RemoveTrivialTry : public OpRewritePattern<TryOp> {
107107
}
108108
};
109109

110+
struct SimplifySelect : public OpRewritePattern<SelectOp> {
111+
using OpRewritePattern<SelectOp>::OpRewritePattern;
112+
113+
LogicalResult matchAndRewrite(SelectOp op,
114+
PatternRewriter &rewriter) const final {
115+
mlir::Operation *trueValueOp = op.getTrueValue().getDefiningOp();
116+
mlir::Operation *falseValueOp = op.getFalseValue().getDefiningOp();
117+
auto trueValueConstOp =
118+
mlir::dyn_cast_if_present<mlir::cir::ConstantOp>(trueValueOp);
119+
auto falseValueConstOp =
120+
mlir::dyn_cast_if_present<mlir::cir::ConstantOp>(falseValueOp);
121+
if (!trueValueConstOp || !falseValueConstOp)
122+
return mlir::failure();
123+
124+
auto trueValue =
125+
mlir::dyn_cast<mlir::cir::BoolAttr>(trueValueConstOp.getValue());
126+
auto falseValue =
127+
mlir::dyn_cast<mlir::cir::BoolAttr>(falseValueConstOp.getValue());
128+
if (!trueValue || !falseValue)
129+
return mlir::failure();
130+
131+
// cir.select if %0 then #true else #false -> %0
132+
if (trueValue.getValue() && !falseValue.getValue()) {
133+
rewriter.replaceAllUsesWith(op, op.getCondition());
134+
rewriter.eraseOp(op);
135+
return mlir::success();
136+
}
137+
138+
// cir.seleft if %0 then #false else #true -> cir.unary not %0
139+
if (!trueValue.getValue() && falseValue.getValue()) {
140+
rewriter.replaceOpWithNewOp<mlir::cir::UnaryOp>(
141+
op, mlir::cir::UnaryOpKind::Not, op.getCondition());
142+
return mlir::success();
143+
}
144+
145+
return mlir::failure();
146+
}
147+
};
148+
110149
//===----------------------------------------------------------------------===//
111150
// CIRSimplifyPass
112151
//===----------------------------------------------------------------------===//
@@ -131,7 +170,8 @@ void populateMergeCleanupPatterns(RewritePatternSet &patterns) {
131170
RemoveRedundantBranches,
132171
RemoveEmptyScope,
133172
RemoveEmptySwitch,
134-
RemoveTrivialTry
173+
RemoveTrivialTry,
174+
SimplifySelect
135175
>(patterns.getContext());
136176
// clang-format on
137177
}

clang/test/CIR/Transforms/select.cir

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// RUN: cir-opt --canonicalize -o %t.cir %s
1+
// RUN: cir-opt -cir-simplify -o %t.cir %s
22
// RUN: FileCheck --input-file=%t.cir %s
33

44
!s32i = !cir.int<s, 32>
@@ -23,4 +23,38 @@ module {
2323
// CHECK: cir.func @fold_false(%[[ARG0:.+]]: !s32i, %[[ARG1:.+]]: !s32i) -> !s32i {
2424
// CHECK-NEXT: cir.return %[[ARG1]] : !s32i
2525
// CHECK-NEXT: }
26+
27+
cir.func @fold_to_const(%arg0 : !cir.bool) -> !s32i {
28+
%0 = cir.const #cir.int<42> : !s32i
29+
%1 = cir.select if %arg0 then %0 else %0 : (!cir.bool, !s32i, !s32i) -> !s32i
30+
cir.return %1 : !s32i
31+
}
32+
33+
// CHECK: cir.func @fold_to_const(%{{.+}}: !cir.bool) -> !s32i {
34+
// CHECK-NEXT: %[[#A:]] = cir.const #cir.int<42> : !s32i
35+
// CHECK-NEXT: cir.return %[[#A]] : !s32i
36+
// CHECK-NEXT: }
37+
38+
cir.func @simplify_1(%arg0 : !cir.bool) -> !cir.bool {
39+
%0 = cir.const #cir.bool<true> : !cir.bool
40+
%1 = cir.const #cir.bool<false> : !cir.bool
41+
%2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
42+
cir.return %2 : !cir.bool
43+
}
44+
45+
// CHECK: cir.func @simplify_1(%[[ARG0:.+]]: !cir.bool) -> !cir.bool {
46+
// CHECK-NEXT: cir.return %[[ARG0]] : !cir.bool
47+
// CHECK-NEXT: }
48+
49+
cir.func @simplify_2(%arg0 : !cir.bool) -> !cir.bool {
50+
%0 = cir.const #cir.bool<false> : !cir.bool
51+
%1 = cir.const #cir.bool<true> : !cir.bool
52+
%2 = cir.select if %arg0 then %0 else %1 : (!cir.bool, !cir.bool, !cir.bool) -> !cir.bool
53+
cir.return %2 : !cir.bool
54+
}
55+
56+
// CHECK: cir.func @simplify_2(%[[ARG0:.+]]: !cir.bool) -> !cir.bool {
57+
// CHECK-NEXT: %[[#A:]] = cir.unary(not, %[[ARG0]]) : !cir.bool, !cir.bool
58+
// CHECK-NEXT: cir.return %[[#A]] : !cir.bool
59+
// CHECK-NEXT: }
2660
}

0 commit comments

Comments
 (0)